summaryrefslogtreecommitdiff
path: root/internal/sshd/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/connection.go')
-rw-r--r--internal/sshd/connection.go82
1 files changed, 62 insertions, 20 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 61234a3..eaae5ca 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -3,6 +3,8 @@ package sshd
import (
"context"
"errors"
+ "net"
+ "strings"
"time"
"golang.org/x/crypto/ssh"
@@ -22,52 +24,91 @@ const KeepAliveMsg = "keepalive@openssh.com"
var EOFTimeout = 10 * time.Second
type connection struct {
- cfg *config.Config
- concurrentSessions *semaphore.Weighted
- remoteAddr string
- sconn *ssh.ServerConn
- maxSessions int64
+ cfg *config.Config
+ concurrentSessions *semaphore.Weighted
+ nconn net.Conn
+ maxSessions int64
+ remoteAddr string
+ started time.Time
+ establishSessionDuration float64
}
-type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error
+type channelHandler func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
-func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
+func newConnection(cfg *config.Config, nconn net.Conn) *connection {
maxSessions := cfg.Server.ConcurrentSessionsLimit
return &connection{
cfg: cfg,
maxSessions: maxSessions,
concurrentSessions: semaphore.NewWeighted(maxSessions),
- remoteAddr: remoteAddr,
- sconn: sconn,
+ nconn: nconn,
+ remoteAddr: nconn.RemoteAddr().String(),
+ started: time.Now(),
}
}
-func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
- ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) {
+ sconn, chans, err := c.initServerConn(ctx, srvCfg)
+ if err != nil {
+ return
+ }
if c.cfg.Server.ClientAliveInterval > 0 {
ticker := time.NewTicker(time.Duration(c.cfg.Server.ClientAliveInterval))
defer ticker.Stop()
- go c.sendKeepAliveMsg(ctx, ticker)
+ go c.sendKeepAliveMsg(ctx, sconn, ticker)
+ }
+
+ c.handleRequests(ctx, sconn, chans, handler)
+
+ reason := sconn.Wait()
+ log.WithContextFields(ctx, log.Fields{
+ "duration_s": time.Since(c.started).Seconds(),
+ "establish_session_duration_s": c.establishSessionDuration,
+ "reason": reason,
+ }).Info("server: handleConn: done")
+}
+
+func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) {
+ sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, srvCfg)
+ if err != nil {
+ msg := "connection: initServerConn: failed to initialize SSH connection"
+
+ logger := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}).WithError(err)
+
+ if strings.Contains(err.Error(), "no common algorithm for host key") || err.Error() == "EOF" {
+ logger.Debug(msg)
+ } else {
+ logger.Warn(msg)
+ }
+
+ return nil, nil, err
}
+ go ssh.DiscardRequests(reqs)
+
+ return sconn, chans, err
+}
+
+func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
if newChannel.ChannelType() != "session" {
- ctxlog.Info("connection: handle: unknown channel type")
+ ctxlog.Info("connection: handleRequests: unknown channel type")
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !c.concurrentSessions.TryAcquire(1) {
- ctxlog.Info("connection: handle: too many concurrent sessions")
+ ctxlog.Info("connection: handleRequests: too many concurrent sessions")
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
metrics.SshdHitMaxSessions.Inc()
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
- ctxlog.WithError(err).Error("connection: handle: accepting channel failed")
+ ctxlog.WithError(err).Error("connection: handleRequests: accepting channel failed")
c.concurrentSessions.Release(1)
continue
}
@@ -76,6 +117,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
defer func(started time.Time) {
metrics.SshdSessionDuration.Observe(time.Since(started).Seconds())
}(time.Now())
+ c.establishSessionDuration = time.Since(c.started).Seconds()
defer c.concurrentSessions.Release(1)
@@ -87,12 +129,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
}()
metrics.SliSshdSessionsTotal.Inc()
- err := handler(ctx, channel, requests)
+ err := handler(sconn, channel, requests)
if err != nil {
c.trackError(err)
}
- ctxlog.Info("connection: handle: done")
+ ctxlog.Info("connection: handleRequests: done")
}()
}
@@ -105,7 +147,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
c.concurrentSessions.Acquire(ctx, c.maxSessions)
}
-func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) {
+func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for {
@@ -113,9 +155,9 @@ func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker)
case <-ctx.Done():
return
case <-ticker.C:
- ctxlog.Debug("session: handleShell: send keepalive message to a client")
+ ctxlog.Debug("connection: sendKeepAliveMsg: send keepalive message to a client")
- c.sconn.SendRequest(KeepAliveMsg, true, nil)
+ sconn.SendRequest(KeepAliveMsg, true, nil)
}
}
}