From 0110b9ea4b49d9236e537fd984d3db7f7b7a2702 Mon Sep 17 00:00:00 2001 From: Igor Drozdov Date: Mon, 23 May 2022 17:16:32 +0400 Subject: Close the connection when context is canceled When graceful shutdown timeout expires, the global context is canceled. All the operations dependent on it are canceled as well. Unfortunately, some of the operations doesn't respect the context. For example, SSH connection initialization. In this case, we need to manually close the connection. One of the options is to wait for ctx.Done() and close the connection --- internal/sshd/sshd.go | 12 +++++++----- internal/sshd/sshd_test.go | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index f856929..d927268 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -146,17 +146,19 @@ func (s *Server) getStatus() status { } func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { + defer s.wg.Done() + metrics.SshdConnectionsInFlight.Inc() defer metrics.SshdConnectionsInFlight.Dec() - remoteAddr := nconn.RemoteAddr().String() - - defer s.wg.Done() - defer nconn.Close() - ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) defer cancel() + go func() { + <-ctx.Done() + nconn.Close() // Close the connection when context is cancelled + }() + remoteAddr := nconn.RemoteAddr().String() ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr}) ctxlog.Debug("server: handleConn: start") diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index d725add..36adc57 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -222,6 +222,35 @@ func TestInvalidServerConfig(t *testing.T) { require.Nil(t, s.Shutdown()) } +func TestClosingHangedConnections(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := setupServerWithContext(t, nil, ctx) + + unauthenticatedRequestStatus := make(chan string) + completed := make(chan bool) + + clientCfg := clientConfig(t) + clientCfg.HostKeyCallback = func(_ string, _ net.Addr, _ ssh.PublicKey) error { + unauthenticatedRequestStatus <- "authentication-started" + <-completed // Wait infinitely + + return nil + } + + go func() { + // Start an SSH connection that never ends + ssh.Dial("tcp", serverUrl, clientCfg) + }() + + require.Equal(t, "authentication-started", <-unauthenticatedRequestStatus) + + require.NoError(t, s.Shutdown()) + cancel() + verifyStatus(t, s, StatusClosed) +} + func setupServer(t *testing.T) *Server { t.Helper() @@ -231,6 +260,12 @@ func setupServer(t *testing.T) *Server { func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server { t.Helper() + return setupServerWithContext(t, cfg, context.Background()) +} + +func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Context) *Server { + t.Helper() + requests := []testserver.TestRequestHandler{ { Path: "/api/v4/internal/authorized_keys", @@ -270,7 +305,7 @@ func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server { s, err := NewServer(cfg) require.NoError(t, err) - go func() { require.NoError(t, s.ListenAndServe(context.Background())) }() + go func() { require.NoError(t, s.ListenAndServe(ctx)) }() t.Cleanup(func() { s.Shutdown() }) verifyStatus(t, s, StatusReady) -- cgit v1.2.1