diff options
Diffstat (limited to 'internal/sshd/connection.go')
-rw-r--r-- | internal/sshd/connection.go | 82 |
1 files changed, 62 insertions, 20 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 61234a3..eaae5ca 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -3,6 +3,8 @@ package sshd import ( "context" "errors" + "net" + "strings" "time" "golang.org/x/crypto/ssh" @@ -22,52 +24,91 @@ const KeepAliveMsg = "keepalive@openssh.com" var EOFTimeout = 10 * time.Second type connection struct { - cfg *config.Config - concurrentSessions *semaphore.Weighted - remoteAddr string - sconn *ssh.ServerConn - maxSessions int64 + cfg *config.Config + concurrentSessions *semaphore.Weighted + nconn net.Conn + maxSessions int64 + remoteAddr string + started time.Time + establishSessionDuration float64 } -type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error +type channelHandler func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error -func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection { +func newConnection(cfg *config.Config, nconn net.Conn) *connection { maxSessions := cfg.Server.ConcurrentSessionsLimit return &connection{ cfg: cfg, maxSessions: maxSessions, concurrentSessions: semaphore.NewWeighted(maxSessions), - remoteAddr: remoteAddr, - sconn: sconn, + nconn: nconn, + remoteAddr: nconn.RemoteAddr().String(), + started: time.Now(), } } -func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) { - ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) +func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) { + sconn, chans, err := c.initServerConn(ctx, srvCfg) + if err != nil { + return + } if c.cfg.Server.ClientAliveInterval > 0 { ticker := time.NewTicker(time.Duration(c.cfg.Server.ClientAliveInterval)) defer ticker.Stop() - go c.sendKeepAliveMsg(ctx, ticker) + go c.sendKeepAliveMsg(ctx, sconn, ticker) + } + + c.handleRequests(ctx, sconn, chans, handler) + + reason := sconn.Wait() + log.WithContextFields(ctx, log.Fields{ + "duration_s": time.Since(c.started).Seconds(), + "establish_session_duration_s": c.establishSessionDuration, + "reason": reason, + }).Info("server: handleConn: done") +} + +func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) { + sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, srvCfg) + if err != nil { + msg := "connection: initServerConn: failed to initialize SSH connection" + + logger := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}).WithError(err) + + if strings.Contains(err.Error(), "no common algorithm for host key") || err.Error() == "EOF" { + logger.Debug(msg) + } else { + logger.Warn(msg) + } + + return nil, nil, err } + go ssh.DiscardRequests(reqs) + + return sconn, chans, err +} + +func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) { + ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested") if newChannel.ChannelType() != "session" { - ctxlog.Info("connection: handle: unknown channel type") + ctxlog.Info("connection: handleRequests: unknown channel type") newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } if !c.concurrentSessions.TryAcquire(1) { - ctxlog.Info("connection: handle: too many concurrent sessions") + ctxlog.Info("connection: handleRequests: too many concurrent sessions") newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") metrics.SshdHitMaxSessions.Inc() continue } channel, requests, err := newChannel.Accept() if err != nil { - ctxlog.WithError(err).Error("connection: handle: accepting channel failed") + ctxlog.WithError(err).Error("connection: handleRequests: accepting channel failed") c.concurrentSessions.Release(1) continue } @@ -76,6 +117,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha defer func(started time.Time) { metrics.SshdSessionDuration.Observe(time.Since(started).Seconds()) }(time.Now()) + c.establishSessionDuration = time.Since(c.started).Seconds() defer c.concurrentSessions.Release(1) @@ -87,12 +129,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha }() metrics.SliSshdSessionsTotal.Inc() - err := handler(ctx, channel, requests) + err := handler(sconn, channel, requests) if err != nil { c.trackError(err) } - ctxlog.Info("connection: handle: done") + ctxlog.Info("connection: handleRequests: done") }() } @@ -105,7 +147,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha c.concurrentSessions.Acquire(ctx, c.maxSessions) } -func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) { +func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) { ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for { @@ -113,9 +155,9 @@ func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) case <-ctx.Done(): return case <-ticker.C: - ctxlog.Debug("session: handleShell: send keepalive message to a client") + ctxlog.Debug("connection: sendKeepAliveMsg: send keepalive message to a client") - c.sconn.SendRequest(KeepAliveMsg, true, nil) + sconn.SendRequest(KeepAliveMsg, true, nil) } } } |