diff options
-rw-r--r-- | internal/sshd/sshd.go | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index c61d527..d20286a 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -95,9 +95,14 @@ func (s *Server) listen(ctx context.Context) error { } if s.Config.Server.ProxyProtocol { + policy, err := s.proxyPolicy() + if err != nil { + return fmt.Errorf("invalid policy configuration: %w", err) + } + sshListener = &proxyproto.Listener{ Listener: sshListener, - Policy: s.requirePolicy(), + Policy: policy, ReadHeaderTimeout: time.Duration(s.Config.Server.ProxyHeaderTimeout), } @@ -200,22 +205,22 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { }) } -func (s *Server) requirePolicy() proxyproto.PolicyFunc { +func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { if len(s.Config.Server.ProxyAllowed) > 0 { - return proxyproto.MustStrictWhiteListPolicy(s.Config.Server.ProxyAllowed) + return proxyproto.StrictWhiteListPolicy(s.Config.Server.ProxyAllowed) } // Set the Policy value based on config // Values are taken from https://github.com/pires/go-proxyproto/blob/195fedcfbfc1be163f3a0d507fac1709e9d81fed/policy.go#L20 switch strings.ToLower(s.Config.Server.ProxyPolicy) { case "require": - return staticProxyPolicy(proxyproto.REQUIRE) + return staticProxyPolicy(proxyproto.REQUIRE), nil case "ignore": - return staticProxyPolicy(proxyproto.IGNORE) + return staticProxyPolicy(proxyproto.IGNORE), nil case "reject": - return staticProxyPolicy(proxyproto.REJECT) + return staticProxyPolicy(proxyproto.REJECT), nil default: - return staticProxyPolicy(proxyproto.USE) + return staticProxyPolicy(proxyproto.USE), nil } } |