summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIgor Drozdov <idrozdov@gitlab.com>2022-05-23 17:16:32 +0400
committerIgor Drozdov <idrozdov@gitlab.com>2022-05-23 19:23:17 +0400
commit0110b9ea4b49d9236e537fd984d3db7f7b7a2702 (patch)
tree51d53adfefe6bb22d0741dabac8fc5c87f6f7d4e
parent6e74b9935d800034a584e3e1bc38c33904c78bdc (diff)
downloadgitlab-shell-0110b9ea4b49d9236e537fd984d3db7f7b7a2702.tar.gz
Close the connection when context is canceled
When graceful shutdown timeout expires, the global context is canceled. All the operations dependent on it are canceled as well. Unfortunately, some of the operations doesn't respect the context. For example, SSH connection initialization. In this case, we need to manually close the connection. One of the options is to wait for ctx.Done() and close the connection
-rw-r--r--internal/sshd/sshd.go12
-rw-r--r--internal/sshd/sshd_test.go37
2 files changed, 43 insertions, 6 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index f856929..d927268 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -146,17 +146,19 @@ func (s *Server) getStatus() status {
}
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
+ defer s.wg.Done()
+
metrics.SshdConnectionsInFlight.Inc()
defer metrics.SshdConnectionsInFlight.Dec()
- remoteAddr := nconn.RemoteAddr().String()
-
- defer s.wg.Done()
- defer nconn.Close()
-
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
+ go func() {
+ <-ctx.Done()
+ nconn.Close() // Close the connection when context is cancelled
+ }()
+ remoteAddr := nconn.RemoteAddr().String()
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})
ctxlog.Debug("server: handleConn: start")
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index d725add..36adc57 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -222,6 +222,35 @@ func TestInvalidServerConfig(t *testing.T) {
require.Nil(t, s.Shutdown())
}
+func TestClosingHangedConnections(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ s := setupServerWithContext(t, nil, ctx)
+
+ unauthenticatedRequestStatus := make(chan string)
+ completed := make(chan bool)
+
+ clientCfg := clientConfig(t)
+ clientCfg.HostKeyCallback = func(_ string, _ net.Addr, _ ssh.PublicKey) error {
+ unauthenticatedRequestStatus <- "authentication-started"
+ <-completed // Wait infinitely
+
+ return nil
+ }
+
+ go func() {
+ // Start an SSH connection that never ends
+ ssh.Dial("tcp", serverUrl, clientCfg)
+ }()
+
+ require.Equal(t, "authentication-started", <-unauthenticatedRequestStatus)
+
+ require.NoError(t, s.Shutdown())
+ cancel()
+ verifyStatus(t, s, StatusClosed)
+}
+
func setupServer(t *testing.T) *Server {
t.Helper()
@@ -231,6 +260,12 @@ func setupServer(t *testing.T) *Server {
func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
t.Helper()
+ return setupServerWithContext(t, cfg, context.Background())
+}
+
+func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Context) *Server {
+ t.Helper()
+
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/authorized_keys",
@@ -270,7 +305,7 @@ func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
s, err := NewServer(cfg)
require.NoError(t, err)
- go func() { require.NoError(t, s.ListenAndServe(context.Background())) }()
+ go func() { require.NoError(t, s.ListenAndServe(ctx)) }()
t.Cleanup(func() { s.Shutdown() })
verifyStatus(t, s, StatusReady)