diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2022-05-23 17:16:32 +0400 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2022-05-23 19:23:17 +0400 |
commit | 0110b9ea4b49d9236e537fd984d3db7f7b7a2702 (patch) | |
tree | 51d53adfefe6bb22d0741dabac8fc5c87f6f7d4e | |
parent | 6e74b9935d800034a584e3e1bc38c33904c78bdc (diff) | |
download | gitlab-shell-0110b9ea4b49d9236e537fd984d3db7f7b7a2702.tar.gz |
Close the connection when context is canceled
When graceful shutdown timeout expires, the global context is
canceled. All the operations dependent on it are canceled as well.
Unfortunately, some of the operations doesn't respect the context.
For example, SSH connection initialization.
In this case, we need to manually close the connection.
One of the options is to wait for ctx.Done() and close the connection
-rw-r--r-- | internal/sshd/sshd.go | 12 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 37 |
2 files changed, 43 insertions, 6 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index f856929..d927268 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -146,17 +146,19 @@ func (s *Server) getStatus() status { } func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { + defer s.wg.Done() + metrics.SshdConnectionsInFlight.Inc() defer metrics.SshdConnectionsInFlight.Dec() - remoteAddr := nconn.RemoteAddr().String() - - defer s.wg.Done() - defer nconn.Close() - ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) defer cancel() + go func() { + <-ctx.Done() + nconn.Close() // Close the connection when context is cancelled + }() + remoteAddr := nconn.RemoteAddr().String() ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr}) ctxlog.Debug("server: handleConn: start") diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index d725add..36adc57 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -222,6 +222,35 @@ func TestInvalidServerConfig(t *testing.T) { require.Nil(t, s.Shutdown()) } +func TestClosingHangedConnections(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := setupServerWithContext(t, nil, ctx) + + unauthenticatedRequestStatus := make(chan string) + completed := make(chan bool) + + clientCfg := clientConfig(t) + clientCfg.HostKeyCallback = func(_ string, _ net.Addr, _ ssh.PublicKey) error { + unauthenticatedRequestStatus <- "authentication-started" + <-completed // Wait infinitely + + return nil + } + + go func() { + // Start an SSH connection that never ends + ssh.Dial("tcp", serverUrl, clientCfg) + }() + + require.Equal(t, "authentication-started", <-unauthenticatedRequestStatus) + + require.NoError(t, s.Shutdown()) + cancel() + verifyStatus(t, s, StatusClosed) +} + func setupServer(t *testing.T) *Server { t.Helper() @@ -231,6 +260,12 @@ func setupServer(t *testing.T) *Server { func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server { t.Helper() + return setupServerWithContext(t, cfg, context.Background()) +} + +func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Context) *Server { + t.Helper() + requests := []testserver.TestRequestHandler{ { Path: "/api/v4/internal/authorized_keys", @@ -270,7 +305,7 @@ func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server { s, err := NewServer(cfg) require.NoError(t, err) - go func() { require.NoError(t, s.ListenAndServe(context.Background())) }() + go func() { require.NoError(t, s.ListenAndServe(ctx)) }() t.Cleanup(func() { s.Shutdown() }) verifyStatus(t, s, StatusReady) |