diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2022-05-16 01:17:52 +0400 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2022-05-16 12:05:32 +0400 |
commit | a77babe96fac9c880061fa63fffabfc8406f11bf (patch) | |
tree | 8ee023b0bc368fec094e57301875535b14ac20ec /internal/sshd | |
parent | 7cde0770f2a29010181f95eef4c1744e16f5e0d8 (diff) | |
download | gitlab-shell-a77babe96fac9c880061fa63fffabfc8406f11bf.tar.gz |
Return error from session handler
Diffstat (limited to 'internal/sshd')
-rw-r--r-- | internal/sshd/connection.go | 9 | ||||
-rw-r--r-- | internal/sshd/connection_test.go | 16 | ||||
-rw-r--r-- | internal/sshd/session.go | 39 | ||||
-rw-r--r-- | internal/sshd/session_test.go | 87 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 8 |
5 files changed, 96 insertions, 63 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 5b1232d..060156d 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -22,7 +22,7 @@ type connection struct { sconn *ssh.ServerConn } -type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) +type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection { return &connection{ @@ -76,7 +76,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha } }() - handler(ctx, channel, requests) + metrics.SliSshdSessionsTotal.Inc() + err := handler(ctx, channel, requests) + if err != nil { + metrics.SliSshdSessionsErrorsTotal.Inc() + } + ctxlog.Info("connection: handle: done") }() } diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index 3bd9bf8..0a06255 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -55,9 +55,14 @@ type fakeConn struct { ssh.Conn sentRequestName string + waitErr error mu sync.Mutex } +func (f *fakeConn) Wait() error { + return f.waitErr +} + func (f *fakeConn) SentRequestName() string { f.mu.Lock() defer f.mu.Unlock() @@ -90,7 +95,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) { numSessions := 0 require.NotPanics(t, func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { numSessions += 1 close(chans) panic("This is a panic") @@ -128,8 +133,9 @@ func TestTooManySessions(t *testing.T) { defer cancel() go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { <-ctx.Done() // Keep the accepted channel open until the end of the test + return nil }) }() @@ -142,9 +148,10 @@ func TestAcceptSessionSucceeds(t *testing.T) { conn, chans := setup(1, newChannel) channelHandled := false - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true close(chans) + return nil }) require.True(t, channelHandled) @@ -160,8 +167,9 @@ func TestAcceptSessionFails(t *testing.T) { channelHandled := false go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true + return nil }) }() diff --git a/internal/sshd/session.go b/internal/sshd/session.go index beb529e..831beb8 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -22,7 +22,6 @@ type session struct { channel ssh.Channel gitlabKeyId string remoteAddr string - success bool // State managed by the session execCmd string @@ -42,11 +41,12 @@ type exitStatusReq struct { ExitStatus uint32 } -func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { +func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error { ctxlog := log.ContextLogger(ctx) ctxlog.Debug("session: handle: entering request loop") + var err error for req := range requests { sessionLog := ctxlog.WithFields(log.Fields{ "bytesize": len(req.Payload), @@ -58,12 +58,14 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { var shouldContinue bool switch req.Type { case "env": - shouldContinue = s.handleEnv(ctx, req) + shouldContinue, err = s.handleEnv(ctx, req) case "exec": - shouldContinue = s.handleExec(ctx, req) + shouldContinue, err = s.handleExec(ctx, req) case "shell": shouldContinue = false - s.exit(ctx, s.handleShell(ctx, req)) + var status uint32 + status, err = s.handleShell(ctx, req) + s.exit(ctx, status) default: // Ignore unknown requests but don't terminate the session shouldContinue = true @@ -84,15 +86,17 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { } ctxlog.Debug("session: handle: exiting request loop") + + return err } -func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { +func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) { var accepted bool var envRequest envRequest if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil { log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request") - return false + return false, err } switch envRequest.Name { @@ -113,23 +117,24 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { ctx, log.Fields{"accepted": accepted, "env_request": envRequest}, ).Debug("session: handleEnv: processed") - return true + return true, nil } -func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool { +func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) { var execRequest execRequest if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil { - return false + return false, err } s.execCmd = execRequest.Command - s.exit(ctx, s.handleShell(ctx, req)) + status, err := s.handleShell(ctx, req) + s.exit(ctx, status) - return false + return false, err } -func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { +func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) { ctxlog := log.ContextLogger(ctx) if req.WantReply { @@ -157,7 +162,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { s.toStderr(ctx, "Failed to parse command: %v\n", err.Error()) } s.toStderr(ctx, "Unknown command: %v\n", s.execCmd) - return 128 + return 128, err } cmdName := reflect.TypeOf(cmd).String() @@ -165,12 +170,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { if err := cmd.Execute(ctx); err != nil { s.toStderr(ctx, "remote: ERROR: %v\n", err.Error()) - return 1 + return 1, err } ctxlog.Info("session: handleShell: command executed successfully") - return 0 + return 0, nil } func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { @@ -183,8 +188,6 @@ func (s *session) exit(ctx context.Context, status uint32) { log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting") req := exitStatusReq{ExitStatus: status} - s.success = status == 0 - s.channel.CloseWrite() s.channel.SendRequest("exit-status", false, ssh.Marshal(req)) } diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go index d0cc8d4..5bc8e7c 100644 --- a/internal/sshd/session_test.go +++ b/internal/sshd/session_test.go @@ -3,6 +3,7 @@ package sshd import ( "bytes" "context" + "errors" "io" "net/http" "testing" @@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) { testCases := []struct { desc string payload []byte + expectedErr error expectedProtocolVersion string expectedResult bool }{ { desc: "invalid payload", payload: []byte("invalid"), + expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"), expectedProtocolVersion: "1", expectedResult: false, }, { desc: "valid payload", payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}), + expectedErr: nil, expectedProtocolVersion: "2", expectedResult: true, }, { desc: "valid payload with forbidden env var", payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}), + expectedErr: nil, expectedProtocolVersion: "1", expectedResult: true, }, @@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) { s := &session{gitProtocolVersion: "1"} r := &ssh.Request{Payload: tc.payload} - require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult) - require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion) + shouldContinue, err := s.handleEnv(context.Background(), r) + + require.Equal(t, tc.expectedErr, err) + require.Equal(t, tc.expectedResult, shouldContinue) + require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion) }) } } @@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) { testCases := []struct { desc string payload []byte + expectedErr error expectedExecCmd string sentRequestName string sentRequestPayload []byte - success bool }{ { desc: "invalid payload", payload: []byte("invalid"), + expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"), expectedExecCmd: "", sentRequestName: "", }, { desc: "valid payload", payload: ssh.Marshal(execRequest{Command: "discover"}), + expectedErr: nil, expectedExecCmd: "discover", sentRequestName: "exit-status", sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}), - success: true, }, } @@ -129,47 +138,53 @@ func TestHandleExec(t *testing.T) { } r := &ssh.Request{Payload: tc.payload} - require.Equal(t, false, s.handleExec(context.Background(), r)) + shouldContinue, err := s.handleExec(context.Background(), r) + + require.Equal(t, tc.expectedErr, err) + require.Equal(t, false, shouldContinue) require.Equal(t, tc.sentRequestName, f.sentRequestName) require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload) - require.Equal(t, tc.success, s.success) }) } } func TestHandleShell(t *testing.T) { testCases := []struct { - desc string - cmd string - errMsg string - gitlabKeyId string - expectedExitCode uint32 - success bool + desc string + cmd string + errMsg string + gitlabKeyId string + expectedErrString string + expectedExitCode uint32 }{ { - desc: "fails to parse command", - cmd: `\`, - errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n", - gitlabKeyId: "root", - expectedExitCode: 128, + desc: "fails to parse command", + cmd: `\`, + errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n", + gitlabKeyId: "root", + expectedErrString: "Invalid SSH command: invalid command line string", + expectedExitCode: 128, }, { - desc: "specified command is unknown", - cmd: "unknown-command", - errMsg: "Unknown command: unknown-command\n", - gitlabKeyId: "root", - expectedExitCode: 128, + desc: "specified command is unknown", + cmd: "unknown-command", + errMsg: "Unknown command: unknown-command\n", + gitlabKeyId: "root", + expectedErrString: "Disallowed command", + expectedExitCode: 128, }, { - desc: "fails to parse command", - cmd: "discover", - gitlabKeyId: "", - errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n", - expectedExitCode: 1, + desc: "fails to parse command", + cmd: "discover", + gitlabKeyId: "", + errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n", + expectedErrString: "Failed to get username: who='' is invalid", + expectedExitCode: 1, }, { - desc: "fails to parse command", - cmd: "discover", - errMsg: "", - gitlabKeyId: "root", - expectedExitCode: 0, + desc: "fails to parse command", + cmd: "discover", + errMsg: "", + gitlabKeyId: "root", + expectedErrString: "", + expectedExitCode: 0, }, } @@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) { } r := &ssh.Request{} - require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r)) + exitCode, err := s.handleShell(context.Background(), r) + + if tc.expectedErrString != "" { + require.Equal(t, tc.expectedErrString, err.Error()) + } + + require.Equal(t, tc.expectedExitCode, exitCode) require.Equal(t, tc.errMsg, out.String()) }) } diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 242e4f2..a9cd302 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -181,7 +181,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { started := time.Now() var establishSessionDuration float64 conn := newConnection(s.Config, remoteAddr, sconn) - conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) { + conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) error { establishSessionDuration = time.Since(started).Seconds() metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) @@ -192,11 +192,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { remoteAddr: remoteAddr, } - metrics.SliSshdSessionsTotal.Inc() - session.handle(ctx, requests) - if !session.success { - metrics.SliSshdSessionsErrorsTotal.Inc() - } + return session.handle(ctx, requests) }) reason := sconn.Wait() |