diff options
author | Patrick Bajao <ebajao@gitlab.com> | 2022-05-17 05:15:43 +0000 |
---|---|---|
committer | Patrick Bajao <ebajao@gitlab.com> | 2022-05-17 05:15:43 +0000 |
commit | 9cb22b2f1618005d3f610e25a15c82aef371d476 (patch) | |
tree | 180d82d6ae834e178174440d75d08b686d57702e | |
parent | 7cde0770f2a29010181f95eef4c1744e16f5e0d8 (diff) | |
parent | 509e04b63c9bee521b6c6536224f07fa458362d8 (diff) | |
download | gitlab-shell-9cb22b2f1618005d3f610e25a15c82aef371d476.tar.gz |
Merge branch 'id-wait-until-gitaly-execution' into 'main'
Wait until all Gitaly sessions are executed
See merge request gitlab-org/gitlab-shell!624
-rw-r--r-- | internal/config/config_test.go | 3 | ||||
-rw-r--r-- | internal/metrics/metrics.go | 10 | ||||
-rw-r--r-- | internal/sshd/connection.go | 31 | ||||
-rw-r--r-- | internal/sshd/connection_test.go | 45 | ||||
-rw-r--r-- | internal/sshd/session.go | 39 | ||||
-rw-r--r-- | internal/sshd/session_test.go | 87 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 8 |
7 files changed, 158 insertions, 65 deletions
diff --git a/internal/config/config_test.go b/internal/config/config_test.go index a929106..32580b8 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -39,7 +39,7 @@ func TestCustomPrometheusMetrics(t *testing.T) { require.NoError(t, err) var actualNames []string - for _, m := range ms[0:9] { + for _, m := range ms[0:10] { actualNames = append(actualNames, m.GetName()) } @@ -47,6 +47,7 @@ func TestCustomPrometheusMetrics(t *testing.T) { "gitlab_shell_http_in_flight_requests", "gitlab_shell_http_request_duration_seconds", "gitlab_shell_http_requests_total", + "gitlab_shell_sshd_canceled_sessions", "gitlab_shell_sshd_concurrent_limited_sessions_total", "gitlab_shell_sshd_in_flight_connections", "gitlab_shell_sshd_session_duration_seconds", diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 5fa5036..e3f335d 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -22,6 +22,7 @@ const ( sshdHitMaxSessionsName = "concurrent_limited_sessions_total" sshdSessionDurationSecondsName = "session_duration_seconds" sshdSessionEstablishedDurationSecondsName = "session_established_duration_seconds" + sshdCanceledSessionsName = "canceled_sessions" sliSshdSessionsTotalName = "gitlab_sli:shell_sshd_sessions:total" sliSshdSessionsErrorsTotalName = "gitlab_sli:shell_sshd_sessions:errors_total" @@ -76,6 +77,15 @@ var ( }, ) + SshdCanceledSessions = promauto.NewCounter( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: sshdSubsystem, + Name: sshdCanceledSessionsName, + Help: "The number of canceled gitlab-sshd sessions.", + }, + ) + SliSshdSessionsTotal = promauto.NewCounter( prometheus.CounterOpts{ Name: sliSshdSessionsTotalName, diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 5b1232d..0295d8f 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -6,6 +6,8 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/sync/semaphore" + grpccodes "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/metrics" @@ -15,19 +17,25 @@ import ( const KeepAliveMsg = "keepalive@openssh.com" +var EOFTimeout = 10 * time.Second + type connection struct { cfg *config.Config concurrentSessions *semaphore.Weighted remoteAddr string sconn *ssh.ServerConn + maxSessions int64 } -type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) +type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection { + maxSessions := cfg.Server.ConcurrentSessionsLimit + return &connection{ cfg: cfg, - concurrentSessions: semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit), + maxSessions: maxSessions, + concurrentSessions: semaphore.NewWeighted(maxSessions), remoteAddr: remoteAddr, sconn: sconn, } @@ -76,10 +84,27 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha } }() - handler(ctx, channel, requests) + metrics.SliSshdSessionsTotal.Inc() + err := handler(ctx, channel, requests) + if err != nil { + if grpcstatus.Convert(err).Code() == grpccodes.Canceled { + metrics.SshdCanceledSessions.Inc() + } else { + metrics.SliSshdSessionsErrorsTotal.Inc() + } + } + ctxlog.Info("connection: handle: done") }() } + + // When a connection has been prematurely closed we block execution until all concurrent sessions are released + // in order to allow Gitaly complete the operations and close all the channels gracefully. + // If it didn't happen within timeout, we unblock the execution + // Related issue: https://gitlab.com/gitlab-org/gitlab-shell/-/issues/563 + ctx, cancel := context.WithTimeout(ctx, EOFTimeout) + defer cancel() + c.concurrentSessions.Acquire(ctx, c.maxSessions) } func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) { diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index 3bd9bf8..f792300 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -7,10 +7,14 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + grpccodes "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/metrics" ) type rejectCall struct { @@ -90,7 +94,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) { numSessions := 0 require.NotPanics(t, func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { numSessions += 1 close(chans) panic("This is a panic") @@ -128,8 +132,9 @@ func TestTooManySessions(t *testing.T) { defer cancel() go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { <-ctx.Done() // Keep the accepted channel open until the end of the test + return nil }) }() @@ -142,9 +147,10 @@ func TestAcceptSessionSucceeds(t *testing.T) { conn, chans := setup(1, newChannel) channelHandled := false - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true close(chans) + return nil }) require.True(t, channelHandled) @@ -160,8 +166,9 @@ func TestAcceptSessionFails(t *testing.T) { channelHandled := false go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true + return nil }) }() @@ -186,3 +193,33 @@ func TestClientAliveInterval(t *testing.T) { require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond) } + +func TestSessionsMetrics(t *testing.T) { + // Unfortunately, there is no working way to reset Counter (not CounterVec) + // https://pkg.go.dev/github.com/prometheus/client_golang/prometheus#pkg-index + initialSessionsTotal := testutil.ToFloat64(metrics.SliSshdSessionsTotal) + initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal) + initialCanceledSessions := testutil.ToFloat64(metrics.SshdCanceledSessions) + + newChannel := &fakeNewChannel{channelType: "session"} + + conn, chans := setup(1, newChannel) + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + close(chans) + return errors.New("custom error") + }) + + require.InDelta(t, initialSessionsTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1) + require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1) + require.InDelta(t, initialCanceledSessions, testutil.ToFloat64(metrics.SshdCanceledSessions), 0.1) + + conn, chans = setup(1, newChannel) + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + close(chans) + return grpcstatus.Error(grpccodes.Canceled, "error") + }) + + require.InDelta(t, initialSessionsTotal+2, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1) + require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1) + require.InDelta(t, initialCanceledSessions+1, testutil.ToFloat64(metrics.SshdCanceledSessions), 0.1) +} diff --git a/internal/sshd/session.go b/internal/sshd/session.go index beb529e..831beb8 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -22,7 +22,6 @@ type session struct { channel ssh.Channel gitlabKeyId string remoteAddr string - success bool // State managed by the session execCmd string @@ -42,11 +41,12 @@ type exitStatusReq struct { ExitStatus uint32 } -func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { +func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error { ctxlog := log.ContextLogger(ctx) ctxlog.Debug("session: handle: entering request loop") + var err error for req := range requests { sessionLog := ctxlog.WithFields(log.Fields{ "bytesize": len(req.Payload), @@ -58,12 +58,14 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { var shouldContinue bool switch req.Type { case "env": - shouldContinue = s.handleEnv(ctx, req) + shouldContinue, err = s.handleEnv(ctx, req) case "exec": - shouldContinue = s.handleExec(ctx, req) + shouldContinue, err = s.handleExec(ctx, req) case "shell": shouldContinue = false - s.exit(ctx, s.handleShell(ctx, req)) + var status uint32 + status, err = s.handleShell(ctx, req) + s.exit(ctx, status) default: // Ignore unknown requests but don't terminate the session shouldContinue = true @@ -84,15 +86,17 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { } ctxlog.Debug("session: handle: exiting request loop") + + return err } -func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { +func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) { var accepted bool var envRequest envRequest if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil { log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request") - return false + return false, err } switch envRequest.Name { @@ -113,23 +117,24 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { ctx, log.Fields{"accepted": accepted, "env_request": envRequest}, ).Debug("session: handleEnv: processed") - return true + return true, nil } -func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool { +func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) { var execRequest execRequest if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil { - return false + return false, err } s.execCmd = execRequest.Command - s.exit(ctx, s.handleShell(ctx, req)) + status, err := s.handleShell(ctx, req) + s.exit(ctx, status) - return false + return false, err } -func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { +func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) { ctxlog := log.ContextLogger(ctx) if req.WantReply { @@ -157,7 +162,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { s.toStderr(ctx, "Failed to parse command: %v\n", err.Error()) } s.toStderr(ctx, "Unknown command: %v\n", s.execCmd) - return 128 + return 128, err } cmdName := reflect.TypeOf(cmd).String() @@ -165,12 +170,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { if err := cmd.Execute(ctx); err != nil { s.toStderr(ctx, "remote: ERROR: %v\n", err.Error()) - return 1 + return 1, err } ctxlog.Info("session: handleShell: command executed successfully") - return 0 + return 0, nil } func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { @@ -183,8 +188,6 @@ func (s *session) exit(ctx context.Context, status uint32) { log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting") req := exitStatusReq{ExitStatus: status} - s.success = status == 0 - s.channel.CloseWrite() s.channel.SendRequest("exit-status", false, ssh.Marshal(req)) } diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go index d0cc8d4..5bc8e7c 100644 --- a/internal/sshd/session_test.go +++ b/internal/sshd/session_test.go @@ -3,6 +3,7 @@ package sshd import ( "bytes" "context" + "errors" "io" "net/http" "testing" @@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) { testCases := []struct { desc string payload []byte + expectedErr error expectedProtocolVersion string expectedResult bool }{ { desc: "invalid payload", payload: []byte("invalid"), + expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"), expectedProtocolVersion: "1", expectedResult: false, }, { desc: "valid payload", payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}), + expectedErr: nil, expectedProtocolVersion: "2", expectedResult: true, }, { desc: "valid payload with forbidden env var", payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}), + expectedErr: nil, expectedProtocolVersion: "1", expectedResult: true, }, @@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) { s := &session{gitProtocolVersion: "1"} r := &ssh.Request{Payload: tc.payload} - require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult) - require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion) + shouldContinue, err := s.handleEnv(context.Background(), r) + + require.Equal(t, tc.expectedErr, err) + require.Equal(t, tc.expectedResult, shouldContinue) + require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion) }) } } @@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) { testCases := []struct { desc string payload []byte + expectedErr error expectedExecCmd string sentRequestName string sentRequestPayload []byte - success bool }{ { desc: "invalid payload", payload: []byte("invalid"), + expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"), expectedExecCmd: "", sentRequestName: "", }, { desc: "valid payload", payload: ssh.Marshal(execRequest{Command: "discover"}), + expectedErr: nil, expectedExecCmd: "discover", sentRequestName: "exit-status", sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}), - success: true, }, } @@ -129,47 +138,53 @@ func TestHandleExec(t *testing.T) { } r := &ssh.Request{Payload: tc.payload} - require.Equal(t, false, s.handleExec(context.Background(), r)) + shouldContinue, err := s.handleExec(context.Background(), r) + + require.Equal(t, tc.expectedErr, err) + require.Equal(t, false, shouldContinue) require.Equal(t, tc.sentRequestName, f.sentRequestName) require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload) - require.Equal(t, tc.success, s.success) }) } } func TestHandleShell(t *testing.T) { testCases := []struct { - desc string - cmd string - errMsg string - gitlabKeyId string - expectedExitCode uint32 - success bool + desc string + cmd string + errMsg string + gitlabKeyId string + expectedErrString string + expectedExitCode uint32 }{ { - desc: "fails to parse command", - cmd: `\`, - errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n", - gitlabKeyId: "root", - expectedExitCode: 128, + desc: "fails to parse command", + cmd: `\`, + errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n", + gitlabKeyId: "root", + expectedErrString: "Invalid SSH command: invalid command line string", + expectedExitCode: 128, }, { - desc: "specified command is unknown", - cmd: "unknown-command", - errMsg: "Unknown command: unknown-command\n", - gitlabKeyId: "root", - expectedExitCode: 128, + desc: "specified command is unknown", + cmd: "unknown-command", + errMsg: "Unknown command: unknown-command\n", + gitlabKeyId: "root", + expectedErrString: "Disallowed command", + expectedExitCode: 128, }, { - desc: "fails to parse command", - cmd: "discover", - gitlabKeyId: "", - errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n", - expectedExitCode: 1, + desc: "fails to parse command", + cmd: "discover", + gitlabKeyId: "", + errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n", + expectedErrString: "Failed to get username: who='' is invalid", + expectedExitCode: 1, }, { - desc: "fails to parse command", - cmd: "discover", - errMsg: "", - gitlabKeyId: "root", - expectedExitCode: 0, + desc: "fails to parse command", + cmd: "discover", + errMsg: "", + gitlabKeyId: "root", + expectedErrString: "", + expectedExitCode: 0, }, } @@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) { } r := &ssh.Request{} - require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r)) + exitCode, err := s.handleShell(context.Background(), r) + + if tc.expectedErrString != "" { + require.Equal(t, tc.expectedErrString, err.Error()) + } + + require.Equal(t, tc.expectedExitCode, exitCode) require.Equal(t, tc.errMsg, out.String()) }) } diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 242e4f2..a9cd302 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -181,7 +181,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { started := time.Now() var establishSessionDuration float64 conn := newConnection(s.Config, remoteAddr, sconn) - conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) { + conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) error { establishSessionDuration = time.Since(started).Seconds() metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) @@ -192,11 +192,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { remoteAddr: remoteAddr, } - metrics.SliSshdSessionsTotal.Inc() - session.handle(ctx, requests) - if !session.success { - metrics.SliSshdSessionsErrorsTotal.Inc() - } + return session.handle(ctx, requests) }) reason := sconn.Wait() |