summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Bajao <ebajao@gitlab.com>2022-05-17 05:15:43 +0000
committerPatrick Bajao <ebajao@gitlab.com>2022-05-17 05:15:43 +0000
commit9cb22b2f1618005d3f610e25a15c82aef371d476 (patch)
tree180d82d6ae834e178174440d75d08b686d57702e
parent7cde0770f2a29010181f95eef4c1744e16f5e0d8 (diff)
parent509e04b63c9bee521b6c6536224f07fa458362d8 (diff)
downloadgitlab-shell-9cb22b2f1618005d3f610e25a15c82aef371d476.tar.gz
Merge branch 'id-wait-until-gitaly-execution' into 'main'
Wait until all Gitaly sessions are executed. See merge request gitlab-org/gitlab-shell!624
-rw-r--r--internal/config/config_test.go3
-rw-r--r--internal/metrics/metrics.go10
-rw-r--r--internal/sshd/connection.go31
-rw-r--r--internal/sshd/connection_test.go45
-rw-r--r--internal/sshd/session.go39
-rw-r--r--internal/sshd/session_test.go87
-rw-r--r--internal/sshd/sshd.go8
7 files changed, 158 insertions, 65 deletions
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index a929106..32580b8 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -39,7 +39,7 @@ func TestCustomPrometheusMetrics(t *testing.T) {
require.NoError(t, err)
var actualNames []string
- for _, m := range ms[0:9] {
+ for _, m := range ms[0:10] {
actualNames = append(actualNames, m.GetName())
}
@@ -47,6 +47,7 @@ func TestCustomPrometheusMetrics(t *testing.T) {
"gitlab_shell_http_in_flight_requests",
"gitlab_shell_http_request_duration_seconds",
"gitlab_shell_http_requests_total",
+ "gitlab_shell_sshd_canceled_sessions",
"gitlab_shell_sshd_concurrent_limited_sessions_total",
"gitlab_shell_sshd_in_flight_connections",
"gitlab_shell_sshd_session_duration_seconds",
diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go
index 5fa5036..e3f335d 100644
--- a/internal/metrics/metrics.go
+++ b/internal/metrics/metrics.go
@@ -22,6 +22,7 @@ const (
sshdHitMaxSessionsName = "concurrent_limited_sessions_total"
sshdSessionDurationSecondsName = "session_duration_seconds"
sshdSessionEstablishedDurationSecondsName = "session_established_duration_seconds"
+ sshdCanceledSessionsName = "canceled_sessions"
sliSshdSessionsTotalName = "gitlab_sli:shell_sshd_sessions:total"
sliSshdSessionsErrorsTotalName = "gitlab_sli:shell_sshd_sessions:errors_total"
@@ -76,6 +77,15 @@ var (
},
)
+ SshdCanceledSessions = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Namespace: namespace,
+ Subsystem: sshdSubsystem,
+ Name: sshdCanceledSessionsName,
+ Help: "The number of canceled gitlab-sshd sessions.",
+ },
+ )
+
SliSshdSessionsTotal = promauto.NewCounter(
prometheus.CounterOpts{
Name: sliSshdSessionsTotalName,
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 5b1232d..0295d8f 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -6,6 +6,8 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
+ grpccodes "google.golang.org/grpc/codes"
+ grpcstatus "google.golang.org/grpc/status"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
@@ -15,19 +17,25 @@ import (
const KeepAliveMsg = "keepalive@openssh.com"
+var EOFTimeout = 10 * time.Second
+
type connection struct {
cfg *config.Config
concurrentSessions *semaphore.Weighted
remoteAddr string
sconn *ssh.ServerConn
+ maxSessions int64
}
-type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
+type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error
func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
+ maxSessions := cfg.Server.ConcurrentSessionsLimit
+
return &connection{
cfg: cfg,
- concurrentSessions: semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit),
+ maxSessions: maxSessions,
+ concurrentSessions: semaphore.NewWeighted(maxSessions),
remoteAddr: remoteAddr,
sconn: sconn,
}
@@ -76,10 +84,27 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
}
}()
- handler(ctx, channel, requests)
+ metrics.SliSshdSessionsTotal.Inc()
+ err := handler(ctx, channel, requests)
+ if err != nil {
+ if grpcstatus.Convert(err).Code() == grpccodes.Canceled {
+ metrics.SshdCanceledSessions.Inc()
+ } else {
+ metrics.SliSshdSessionsErrorsTotal.Inc()
+ }
+ }
+
ctxlog.Info("connection: handle: done")
}()
}
+
+ // When a connection has been prematurely closed, we block execution until all concurrent sessions are released
+ // in order to allow Gitaly to complete its operations and close all the channels gracefully.
+ // If that does not happen within EOFTimeout, we unblock the execution.
+ // Related issue: https://gitlab.com/gitlab-org/gitlab-shell/-/issues/563
+ ctx, cancel := context.WithTimeout(ctx, EOFTimeout)
+ defer cancel()
+ c.concurrentSessions.Acquire(ctx, c.maxSessions)
}
func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) {
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go
index 3bd9bf8..f792300 100644
--- a/internal/sshd/connection_test.go
+++ b/internal/sshd/connection_test.go
@@ -7,10 +7,14 @@ import (
"testing"
"time"
+ "github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
+ grpccodes "google.golang.org/grpc/codes"
+ grpcstatus "google.golang.org/grpc/status"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
)
type rejectCall struct {
@@ -90,7 +94,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
@@ -128,8 +132,9 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
+ return nil
})
}()
@@ -142,9 +147,10 @@ func TestAcceptSessionSucceeds(t *testing.T) {
conn, chans := setup(1, newChannel)
channelHandled := false
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
+ return nil
})
require.True(t, channelHandled)
@@ -160,8 +166,9 @@ func TestAcceptSessionFails(t *testing.T) {
channelHandled := false
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
+ return nil
})
}()
@@ -186,3 +193,33 @@ func TestClientAliveInterval(t *testing.T) {
require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond)
}
+
+func TestSessionsMetrics(t *testing.T) {
+ // Unfortunately, there is no working way to reset Counter (not CounterVec)
+ // https://pkg.go.dev/github.com/prometheus/client_golang/prometheus#pkg-index
+ initialSessionsTotal := testutil.ToFloat64(metrics.SliSshdSessionsTotal)
+ initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal)
+ initialCanceledSessions := testutil.ToFloat64(metrics.SshdCanceledSessions)
+
+ newChannel := &fakeNewChannel{channelType: "session"}
+
+ conn, chans := setup(1, newChannel)
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ close(chans)
+ return errors.New("custom error")
+ })
+
+ require.InDelta(t, initialSessionsTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
+ require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
+ require.InDelta(t, initialCanceledSessions, testutil.ToFloat64(metrics.SshdCanceledSessions), 0.1)
+
+ conn, chans = setup(1, newChannel)
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ close(chans)
+ return grpcstatus.Error(grpccodes.Canceled, "error")
+ })
+
+ require.InDelta(t, initialSessionsTotal+2, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
+ require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
+ require.InDelta(t, initialCanceledSessions+1, testutil.ToFloat64(metrics.SshdCanceledSessions), 0.1)
+}
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
index beb529e..831beb8 100644
--- a/internal/sshd/session.go
+++ b/internal/sshd/session.go
@@ -22,7 +22,6 @@ type session struct {
channel ssh.Channel
gitlabKeyId string
remoteAddr string
- success bool
// State managed by the session
execCmd string
@@ -42,11 +41,12 @@ type exitStatusReq struct {
ExitStatus uint32
}
-func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
+func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error {
ctxlog := log.ContextLogger(ctx)
ctxlog.Debug("session: handle: entering request loop")
+ var err error
for req := range requests {
sessionLog := ctxlog.WithFields(log.Fields{
"bytesize": len(req.Payload),
@@ -58,12 +58,14 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
var shouldContinue bool
switch req.Type {
case "env":
- shouldContinue = s.handleEnv(ctx, req)
+ shouldContinue, err = s.handleEnv(ctx, req)
case "exec":
- shouldContinue = s.handleExec(ctx, req)
+ shouldContinue, err = s.handleExec(ctx, req)
case "shell":
shouldContinue = false
- s.exit(ctx, s.handleShell(ctx, req))
+ var status uint32
+ status, err = s.handleShell(ctx, req)
+ s.exit(ctx, status)
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
@@ -84,15 +86,17 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
}
ctxlog.Debug("session: handle: exiting request loop")
+
+ return err
}
-func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
+func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
var accepted bool
var envRequest envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
- return false
+ return false, err
}
switch envRequest.Name {
@@ -113,23 +117,24 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
).Debug("session: handleEnv: processed")
- return true
+ return true, nil
}
-func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
+func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) {
var execRequest execRequest
if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
- return false
+ return false, err
}
s.execCmd = execRequest.Command
- s.exit(ctx, s.handleShell(ctx, req))
+ status, err := s.handleShell(ctx, req)
+ s.exit(ctx, status)
- return false
+ return false, err
}
-func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
+func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) {
ctxlog := log.ContextLogger(ctx)
if req.WantReply {
@@ -157,7 +162,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
s.toStderr(ctx, "Failed to parse command: %v\n", err.Error())
}
s.toStderr(ctx, "Unknown command: %v\n", s.execCmd)
- return 128
+ return 128, err
}
cmdName := reflect.TypeOf(cmd).String()
@@ -165,12 +170,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
if err := cmd.Execute(ctx); err != nil {
s.toStderr(ctx, "remote: ERROR: %v\n", err.Error())
- return 1
+ return 1, err
}
ctxlog.Info("session: handleShell: command executed successfully")
- return 0
+ return 0, nil
}
func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
@@ -183,8 +188,6 @@ func (s *session) exit(ctx context.Context, status uint32) {
log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
- s.success = status == 0
-
s.channel.CloseWrite()
s.channel.SendRequest("exit-status", false, ssh.Marshal(req))
}
diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go
index d0cc8d4..5bc8e7c 100644
--- a/internal/sshd/session_test.go
+++ b/internal/sshd/session_test.go
@@ -3,6 +3,7 @@ package sshd
import (
"bytes"
"context"
+ "errors"
"io"
"net/http"
"testing"
@@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) {
testCases := []struct {
desc string
payload []byte
+ expectedErr error
expectedProtocolVersion string
expectedResult bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
+ expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"),
expectedProtocolVersion: "1",
expectedResult: false,
}, {
desc: "valid payload",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
+ expectedErr: nil,
expectedProtocolVersion: "2",
expectedResult: true,
}, {
desc: "valid payload with forbidden env var",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
+ expectedErr: nil,
expectedProtocolVersion: "1",
expectedResult: true,
},
@@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) {
s := &session{gitProtocolVersion: "1"}
r := &ssh.Request{Payload: tc.payload}
- require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
- require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
+ shouldContinue, err := s.handleEnv(context.Background(), r)
+
+ require.Equal(t, tc.expectedErr, err)
+ require.Equal(t, tc.expectedResult, shouldContinue)
+ require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion)
})
}
}
@@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) {
testCases := []struct {
desc string
payload []byte
+ expectedErr error
expectedExecCmd string
sentRequestName string
sentRequestPayload []byte
- success bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
+ expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"),
expectedExecCmd: "",
sentRequestName: "",
}, {
desc: "valid payload",
payload: ssh.Marshal(execRequest{Command: "discover"}),
+ expectedErr: nil,
expectedExecCmd: "discover",
sentRequestName: "exit-status",
sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
- success: true,
},
}
@@ -129,47 +138,53 @@ func TestHandleExec(t *testing.T) {
}
r := &ssh.Request{Payload: tc.payload}
- require.Equal(t, false, s.handleExec(context.Background(), r))
+ shouldContinue, err := s.handleExec(context.Background(), r)
+
+ require.Equal(t, tc.expectedErr, err)
+ require.Equal(t, false, shouldContinue)
require.Equal(t, tc.sentRequestName, f.sentRequestName)
require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
- require.Equal(t, tc.success, s.success)
})
}
}
func TestHandleShell(t *testing.T) {
testCases := []struct {
- desc string
- cmd string
- errMsg string
- gitlabKeyId string
- expectedExitCode uint32
- success bool
+ desc string
+ cmd string
+ errMsg string
+ gitlabKeyId string
+ expectedErrString string
+ expectedExitCode uint32
}{
{
- desc: "fails to parse command",
- cmd: `\`,
- errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
- gitlabKeyId: "root",
- expectedExitCode: 128,
+ desc: "fails to parse command",
+ cmd: `\`,
+ errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
+ gitlabKeyId: "root",
+ expectedErrString: "Invalid SSH command: invalid command line string",
+ expectedExitCode: 128,
}, {
- desc: "specified command is unknown",
- cmd: "unknown-command",
- errMsg: "Unknown command: unknown-command\n",
- gitlabKeyId: "root",
- expectedExitCode: 128,
+ desc: "specified command is unknown",
+ cmd: "unknown-command",
+ errMsg: "Unknown command: unknown-command\n",
+ gitlabKeyId: "root",
+ expectedErrString: "Disallowed command",
+ expectedExitCode: 128,
}, {
- desc: "fails to parse command",
- cmd: "discover",
- gitlabKeyId: "",
- errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
- expectedExitCode: 1,
+ desc: "fails to parse command",
+ cmd: "discover",
+ gitlabKeyId: "",
+ errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
+ expectedErrString: "Failed to get username: who='' is invalid",
+ expectedExitCode: 1,
}, {
- desc: "fails to parse command",
- cmd: "discover",
- errMsg: "",
- gitlabKeyId: "root",
- expectedExitCode: 0,
+ desc: "fails to parse command",
+ cmd: "discover",
+ errMsg: "",
+ gitlabKeyId: "root",
+ expectedErrString: "",
+ expectedExitCode: 0,
},
}
@@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) {
}
r := &ssh.Request{}
- require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
+ exitCode, err := s.handleShell(context.Background(), r)
+
+ if tc.expectedErrString != "" {
+ require.Equal(t, tc.expectedErrString, err.Error())
+ }
+
+ require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.errMsg, out.String())
})
}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 242e4f2..a9cd302 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -181,7 +181,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started := time.Now()
var establishSessionDuration float64
conn := newConnection(s.Config, remoteAddr, sconn)
- conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
+ conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) error {
establishSessionDuration = time.Since(started).Seconds()
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
@@ -192,11 +192,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr: remoteAddr,
}
- metrics.SliSshdSessionsTotal.Inc()
- session.handle(ctx, requests)
- if !session.success {
- metrics.SliSshdSessionsErrorsTotal.Inc()
- }
+ return session.handle(ctx, requests)
})
reason := sconn.Wait()