diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2022-05-10 23:16:22 +0400 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2022-05-10 23:23:53 +0400 |
commit | 709c5dd75a7c1a2a0f3296d76ddc654191841213 (patch) | |
tree | d80a8b1ed3d340116770122b99b56bf43d2bad88 /internal | |
parent | 733845f9abec43b6573ba3a1167cc27ff2bfc199 (diff) | |
download | gitlab-shell-709c5dd75a7c1a2a0f3296d76ddc654191841213.tar.gz |
Make PROXY policy configurable
It would give us more flexibility when we decide to enable
PROXY protocol
Diffstat (limited to 'internal')
-rw-r--r-- | internal/config/config.go | 1 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 18 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 106 |
3 files changed, 110 insertions, 15 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index ff0c79a..ab88d72 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,6 +24,7 @@ const ( type ServerConfig struct { Listen string `yaml:"listen,omitempty"` ProxyProtocol bool `yaml:"proxy_protocol,omitempty"` + ProxyPolicy string `yaml:"proxy_policy,omitempty"` WebListen string `yaml:"web_listen,omitempty"` ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"` GracePeriodSeconds uint64 `yaml:"grace_period"` diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 49b8ab9..c2758f0 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strings" "sync" "time" @@ -95,7 +96,7 @@ func (s *Server) listen(ctx context.Context) error { if s.Config.Server.ProxyProtocol { sshListener = &proxyproto.Listener{ Listener: sshListener, - Policy: unconditionalRequirePolicy, + Policy: s.requirePolicy, ReadHeaderTimeout: ProxyHeaderTimeout, } @@ -210,6 +211,17 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { }).Info("server: handleConn: done") } -func unconditionalRequirePolicy(_ net.Addr) (proxyproto.Policy, error) { - return proxyproto.REQUIRE, nil +func (s *Server) requirePolicy(_ net.Addr) (proxyproto.Policy, error) { + // Set the Policy value based on config + // Values are taken from https://github.com/pires/go-proxyproto/blob/195fedcfbfc1be163f3a0d507fac1709e9d81fed/policy.go#L20 + switch strings.ToLower(s.Config.Server.ProxyPolicy) { + case "require": + return proxyproto.REQUIRE, nil + case "ignore": + return proxyproto.IGNORE, nil + case "reject": + return proxyproto.REJECT, nil + default: + return proxyproto.USE, nil + } } diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index 0c6a8ec..d725add 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -3,6 +3,7 @@ package sshd import ( "context" "fmt" + "net" "net/http" "net/http/httptest" "os" @@ -10,6 +11,7 @@ import ( "testing" "time" + "github.com/pires/go-proxyproto" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -48,15 +50,101 @@ func TestListenAndServe(t *testing.T) { } func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testing.T) { - setupServerWithProxyProtocolEnabled(t) + target, err := net.ResolveTCPAddr("tcp", serverUrl) + require.NoError(t, err) - client, err := ssh.Dial("tcp", serverUrl, clientConfig(t)) - if client != nil { - client.Close() + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + }, + DestinationAddr: target, + } + + testCases := []struct { + desc string + proxyPolicy string + header *proxyproto.Header + isRejected bool + }{ + { + desc: "USE (default) without a header", + proxyPolicy: "", + header: nil, + isRejected: false, + }, + { + desc: "USE (default) with a header", + proxyPolicy: "", + header: header, + isRejected: false, + }, + { + desc: "REQUIRE without a header", + proxyPolicy: "require", + header: nil, + isRejected: true, + }, + { + desc: "REQUIRE with a header", + proxyPolicy: "require", + header: header, + isRejected: false, + }, + { + desc: "REJECT without a header", + proxyPolicy: "reject", + header: nil, + isRejected: false, + }, + { + desc: "REJECT with a header", + proxyPolicy: "reject", + header: header, + isRejected: true, + }, + { + desc: "IGNORE without a header", + proxyPolicy: "ignore", + header: nil, + isRejected: false, + }, + { + desc: "IGNORE with a header", + proxyPolicy: "ignore", + header: header, + isRejected: false, + }, } - require.Error(t, err, "Expected plain SSH request to be failed") - require.Regexp(t, "ssh: handshake failed", err.Error()) + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + setupServerWithConfig(t, &config.Config{Server: config.ServerConfig{ProxyProtocol: true, ProxyPolicy: tc.proxyPolicy}}) + + conn, err := net.DialTCP("tcp", nil, target) + require.NoError(t, err) + + if tc.header != nil { + _, err := header.WriteTo(conn) + require.NoError(t, err) + } + + sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t)) + if sshConn != nil { + sshConn.Close() + } + + if tc.isRejected { + require.Error(t, err, "Expected plain SSH request to be failed") + require.Regexp(t, "ssh: handshake failed", err.Error()) + } else { + require.NoError(t, err) + } + }) + } } func TestCorrelationId(t *testing.T) { @@ -140,12 +228,6 @@ func setupServer(t *testing.T) *Server { return setupServerWithConfig(t, nil) } -func setupServerWithProxyProtocolEnabled(t *testing.T) *Server { - t.Helper() - - return setupServerWithConfig(t, &config.Config{Server: config.ServerConfig{ProxyProtocol: true}}) -} - func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server { t.Helper() |