diff options
-rw-r--r-- | internal/sshd/sshd.go | 5 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 42 |
2 files changed, 42 insertions, 5 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 19fa661..d765faf 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -94,6 +94,7 @@ func (s *Server) listen(ctx context.Context) error { if s.Config.Server.ProxyProtocol { sshListener = &proxyproto.Listener{ Listener: sshListener, + Policy: unconditionalRequirePolicy, ReadHeaderTimeout: ProxyHeaderTimeout, } @@ -185,3 +186,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { ctxlog.Info("server: handleConn: done") } + +func unconditionalRequirePolicy(_ net.Addr) (proxyproto.Policy, error) { + return proxyproto.REQUIRE, nil +} diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index 71f7733..455a830 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -47,6 +47,19 @@ func TestListenAndServe(t *testing.T) { verifyStatus(t, s, StatusClosed) } +func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testing.T) { + s := setupServerWithProxyProtocolEnabled(t) + defer s.Shutdown() + + client, err := ssh.Dial("tcp", serverUrl, clientConfig(t)) + if client != nil { + client.Close() + } + + require.Error(t, err, "Expected plain SSH request to be failed") + require.Equal(t, err.Error(), "ssh: handshake failed: EOF") +} + func TestCorrelationId(t *testing.T) { setupServer(t) @@ -125,6 +138,18 @@ func TestInvalidServerConfig(t *testing.T) { func setupServer(t *testing.T) *Server { t.Helper() + return setupServerWithConfig(t, nil) +} + +func setupServerWithProxyProtocolEnabled(t *testing.T) *Server { + t.Helper() + + return setupServerWithConfig(t, &config.Config{Server: config.ServerConfig{ProxyProtocol: true}}) +} + +func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server { + t.Helper() + requests := []testserver.TestRequestHandler{ { Path: "/api/v4/internal/authorized_keys", @@ -148,13 +173,20 @@ func setupServer(t *testing.T) *Server { testhelper.PrepareTestRootDir(t) url := testserver.StartSocketHttpServer(t, requests) - srvCfg := config.ServerConfig{ - Listen: serverUrl, - ConcurrentSessionsLimit: 1, - HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}, + + if cfg == nil { + cfg = &config.Config{} } - s, err := NewServer(&config.Config{User: user, RootDir: "/tmp", GitlabUrl: url, Server: srvCfg}) + // All things that don't need to be configurable in tests yet + cfg.GitlabUrl = url + cfg.RootDir = "/tmp" + cfg.User = user + cfg.Server.Listen = serverUrl + cfg.Server.ConcurrentSessionsLimit = 1 + cfg.Server.HostKeyFiles = []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")} + + s, err := NewServer(cfg) require.NoError(t, err) go func() { require.NoError(t, s.ListenAndServe(context.Background())) }() |