diff options
-rw-r--r-- | internal/sshd/connection.go | 82 | ||||
-rw-r--r-- | internal/sshd/connection_test.go | 26 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 39 |
3 files changed, 78 insertions, 69 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 61234a3..eaae5ca 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -3,6 +3,8 @@ package sshd import ( "context" "errors" + "net" + "strings" "time" "golang.org/x/crypto/ssh" @@ -22,52 +24,91 @@ const KeepAliveMsg = "keepalive@openssh.com" var EOFTimeout = 10 * time.Second type connection struct { - cfg *config.Config - concurrentSessions *semaphore.Weighted - remoteAddr string - sconn *ssh.ServerConn - maxSessions int64 + cfg *config.Config + concurrentSessions *semaphore.Weighted + nconn net.Conn + maxSessions int64 + remoteAddr string + started time.Time + establishSessionDuration float64 } -type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error +type channelHandler func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error -func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection { +func newConnection(cfg *config.Config, nconn net.Conn) *connection { maxSessions := cfg.Server.ConcurrentSessionsLimit return &connection{ cfg: cfg, maxSessions: maxSessions, concurrentSessions: semaphore.NewWeighted(maxSessions), - remoteAddr: remoteAddr, - sconn: sconn, + nconn: nconn, + remoteAddr: nconn.RemoteAddr().String(), + started: time.Now(), } } -func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) { - ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) +func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) { + sconn, chans, err := c.initServerConn(ctx, srvCfg) + if err != nil { + return + } if c.cfg.Server.ClientAliveInterval > 0 { ticker := time.NewTicker(time.Duration(c.cfg.Server.ClientAliveInterval)) defer ticker.Stop() - go c.sendKeepAliveMsg(ctx, ticker) + go c.sendKeepAliveMsg(ctx, sconn, ticker) + } + + c.handleRequests(ctx, sconn, chans, handler) + + reason := sconn.Wait() + log.WithContextFields(ctx, log.Fields{ + "duration_s": time.Since(c.started).Seconds(), + "establish_session_duration_s": c.establishSessionDuration, + "reason": reason, + }).Info("server: handleConn: done") +} + +func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) { + sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, srvCfg) + if err != nil { + msg := "connection: initServerConn: failed to initialize SSH connection" + + logger := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}).WithError(err) + + if strings.Contains(err.Error(), "no common algorithm for host key") || err.Error() == "EOF" { + logger.Debug(msg) + } else { + logger.Warn(msg) + } + + return nil, nil, err } + go ssh.DiscardRequests(reqs) + + return sconn, chans, err +} + +func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) { + ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested") if newChannel.ChannelType() != "session" { - ctxlog.Info("connection: handle: unknown channel type") + ctxlog.Info("connection: handleRequests: unknown channel type") newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } if !c.concurrentSessions.TryAcquire(1) { - ctxlog.Info("connection: handle: too many concurrent sessions") + ctxlog.Info("connection: handleRequests: too many concurrent sessions") newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") metrics.SshdHitMaxSessions.Inc() continue } channel, requests, err := newChannel.Accept() if err != nil { - ctxlog.WithError(err).Error("connection: handle: accepting channel failed") + ctxlog.WithError(err).Error("connection: handleRequests: accepting channel failed") c.concurrentSessions.Release(1) continue } @@ -76,6 +117,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha defer func(started time.Time) { metrics.SshdSessionDuration.Observe(time.Since(started).Seconds()) }(time.Now()) + c.establishSessionDuration = time.Since(c.started).Seconds() defer c.concurrentSessions.Release(1) @@ -87,12 +129,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha }() metrics.SliSshdSessionsTotal.Inc() - err := handler(ctx, channel, requests) + err := handler(sconn, channel, requests) if err != nil { c.trackError(err) } - ctxlog.Info("connection: handle: done") + ctxlog.Info("connection: handleRequests: done") }() } @@ -105,7 +147,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha c.concurrentSessions.Acquire(ctx, c.maxSessions) } -func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) { +func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) { ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for { @@ -113,9 +155,9 @@ func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) case <-ctx.Done(): return case <-ticker.C: - ctxlog.Debug("session: handleShell: send keepalive message to a client") + ctxlog.Debug("connection: sendKeepAliveMsg: send keepalive message to a client") - c.sconn.SendRequest(KeepAliveMsg, true, nil) + sconn.SendRequest(KeepAliveMsg, true, nil) } } } diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index a6dad8d..a5225b2 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -10,6 +10,7 @@ import ( "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + "golang.org/x/sync/semaphore" grpccodes "google.golang.org/grpc/codes" grpcstatus "google.golang.org/grpc/status" @@ -81,7 +82,7 @@ func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (boo func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) { cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum}} - conn := newConnection(cfg, "127.0.0.1:50000", &ssh.ServerConn{&fakeConn{}, nil}) + conn := &connection{cfg: cfg, concurrentSessions: semaphore.NewWeighted(sessionsNum)} chans := make(chan ssh.NewChannel, 1) chans <- newChannel @@ -95,7 +96,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) { numSessions := 0 require.NotPanics(t, func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { numSessions += 1 close(chans) panic("This is a panic") @@ -113,7 +114,7 @@ func TestUnknownChannelType(t *testing.T) { conn, chans := setup(1, newChannel) go func() { - conn.handle(context.Background(), chans, nil) + conn.handleRequests(context.Background(), nil, chans, nil) }() rejectionData := <-rejectCh @@ -133,7 +134,7 @@ func TestTooManySessions(t *testing.T) { defer cancel() go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { <-ctx.Done() // Keep the accepted channel open until the end of the test return nil }) @@ -148,7 +149,7 @@ func TestAcceptSessionSucceeds(t *testing.T) { conn, chans := setup(1, newChannel) channelHandled := false - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true close(chans) return nil @@ -167,7 +168,7 @@ func TestAcceptSessionFails(t *testing.T) { channelHandled := false go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true return nil }) @@ -185,12 +186,11 @@ func TestAcceptSessionFails(t *testing.T) { func TestClientAliveInterval(t *testing.T) { f := &fakeConn{} - conn := newConnection(&config.Config{}, "127.0.0.1:50000", &ssh.ServerConn{f, nil}) - ticker := time.NewTicker(time.Millisecond) defer ticker.Stop() - go conn.sendKeepAliveMsg(context.Background(), ticker) + conn := &connection{} + go conn.sendKeepAliveMsg(context.Background(), &ssh.ServerConn{f, nil}, ticker) require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond) } @@ -204,7 +204,7 @@ func TestSessionsMetrics(t *testing.T) { newChannel := &fakeNewChannel{channelType: "session"} conn, chans := setup(1, newChannel) - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { close(chans) return errors.New("custom error") }) @@ -213,7 +213,7 @@ func TestSessionsMetrics(t *testing.T) { require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1) conn, chans = setup(1, newChannel) - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { close(chans) return grpcstatus.Error(grpccodes.Canceled, "canceled") }) @@ -222,7 +222,7 @@ func TestSessionsMetrics(t *testing.T) { require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1) conn, chans = setup(1, newChannel) - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { close(chans) return &client.ApiError{"api error"} }) @@ -231,7 +231,7 @@ func TestSessionsMetrics(t *testing.T) { require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1) conn, chans = setup(1, newChannel) - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { close(chans) return grpcstatus.Error(grpccodes.Unavailable, "unavailable") }) diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index dbb8709..f856929 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -10,7 +10,6 @@ import ( "time" "github.com/pires/go-proxyproto" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/internal/config" @@ -39,18 +38,6 @@ type Server struct { serverConfig *serverConfig } -func logSSHInitError(ctxlog *logrus.Entry, err error) { - msg := "server: handleConn: failed to initialize SSH connection" - - logger := ctxlog.WithError(err) - - if strings.Contains(err.Error(), "no common algorithm for host key") || err.Error() == "EOF" { - logger.Debug(msg) - } else { - logger.Warn(msg) - } -} - func NewServer(cfg *config.Config) (*Server, error) { serverConfig, err := newServerConfig(cfg) if err != nil { @@ -171,6 +158,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { defer cancel() ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr}) + ctxlog.Debug("server: handleConn: start") // Prevent a panic in a single connection from taking out the whole server defer func() { @@ -181,22 +169,8 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { } }() - ctxlog.Debug("server: handleConn: start") - - sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx)) - if err != nil { - logSSHInitError(ctxlog, err) - return - } - go ssh.DiscardRequests(reqs) - - started := time.Now() - var establishSessionDuration float64 - conn := newConnection(s.Config, remoteAddr, sconn) - conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) error { - establishSessionDuration = time.Since(started).Seconds() - metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) - + conn := newConnection(s.Config, nconn) + conn.handle(ctx, s.serverConfig.get(ctx), func(sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error { session := &session{ cfg: s.Config, channel: channel, @@ -206,13 +180,6 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { return session.handle(ctx, requests) }) - - reason := sconn.Wait() - ctxlog.WithFields(log.Fields{ - "duration_s": time.Since(started).Seconds(), - "establish_session_duration_s": establishSessionDuration, - "reason": reason, - }).Info("server: handleConn: done") } func (s *Server) requirePolicy(_ net.Addr) (proxyproto.Policy, error) { |