diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2021-09-15 18:58:27 +0300 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2021-09-15 19:41:33 +0300 |
commit | e96e13301904bfa6eb514667df9a7803828a7da9 (patch) | |
tree | 5713e182841a41acc9e2adfd48697cbfe9eb8e67 | |
parent | 7d60d7a09658041c959c92a7776feceb64b735f4 (diff) | |
download | gitlab-shell-e96e13301904bfa6eb514667df9a7803828a7da9.tar.gz |
Extract server config related code out of sshd.go
-rw-r--r-- | internal/sshd/server_config.go | 94 | ||||
-rw-r--r-- | internal/sshd/server_config_test.go | 105 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 75 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 16 |
4 files changed, 217 insertions, 73 deletions
diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go new file mode 100644 index 0000000..7306944 --- /dev/null +++ b/internal/sshd/server_config.go @@ -0,0 +1,94 @@ +package sshd + +import ( + "context" + "encoding/base64" + "fmt" + "os" + "strconv" + "time" + + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" + + "gitlab.com/gitlab-org/labkit/log" +) + +type serverConfig struct { + cfg *config.Config + hostKeys []ssh.Signer + authorizedKeysClient *authorizedkeys.Client +} + +func newServerConfig(cfg *config.Config) (*serverConfig, error) { + authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + } + + var hostKeys []ssh.Signer + for _, filename := range cfg.Server.HostKeyFiles { + keyRaw, err := os.ReadFile(filename) + if err != nil { + log.WithError(err).Warnf("Failed to read host key %v", filename) + continue + } + key, err := ssh.ParsePrivateKey(keyRaw) + if err != nil { + log.WithError(err).Warnf("Failed to parse host key %v", filename) + continue + } + + hostKeys = append(hostKeys, key) + } + if len(hostKeys) == 0 { + return nil, fmt.Errorf("No host keys could be loaded, aborting") + } + + return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil +} + +func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) { + if user != s.cfg.User { + return nil, fmt.Errorf("unknown user") + } + if key.Type() == ssh.KeyAlgoDSA { + return nil, fmt.Errorf("DSA is prohibited") + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) + if err != nil { + return nil, err + } + + return res, nil +} + +func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig { + sshCfg := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + res, err := s.getAuthKey(ctx, conn.User(), key) + if err != nil { + return nil, err + } + + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "key-id": strconv.FormatInt(res.Id, 10), + }, + }, nil + }, + } + + for _, key := range s.hostKeys { + sshCfg.AddHostKey(key) + } + + return sshCfg +} diff --git a/internal/sshd/server_config_test.go b/internal/sshd/server_config_test.go new file mode 100644 index 0000000..58bd3e1 --- /dev/null +++ b/internal/sshd/server_config_test.go @@ -0,0 +1,105 @@ +package sshd + +import ( + "context" + "crypto/dsa" + "crypto/rand" + "crypto/rsa" + "path" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +func TestNewServerConfigWithoutHosts(t *testing.T) { + _, err := newServerConfig(&config.Config{GitlabUrl: "http://localhost"}) + + require.Error(t, err) + require.Equal(t, "No host keys could be loaded, aborting", err.Error()) +} + +func TestFailedAuthorizedKeysClient(t *testing.T) { + _, err := newServerConfig(&config.Config{GitlabUrl: "ftp://localhost"}) + + require.Error(t, err) + require.Equal(t, "failed to initialize GitLab client: Error creating http client: unknown GitLab URL prefix", err.Error()) +} + +func TestFailedGetAuthKey(t *testing.T) { + testhelper.PrepareTestRootDir(t) + + srvCfg := config.ServerConfig{ + Listen: "127.0.0.1", + ConcurrentSessionsLimit: 1, + HostKeyFiles: []string{ + path.Join(testhelper.TestRoot, "certs/valid/server.key"), + path.Join(testhelper.TestRoot, "certs/invalid-path.key"), + path.Join(testhelper.TestRoot, "certs/invalid/server.crt"), + }, + } + + cfg, err := newServerConfig( + &config.Config{GitlabUrl: "http://localhost", User: "user", Server: srvCfg}, + ) + require.NoError(t, err) + + testCases := []struct { + desc string + user string + key ssh.PublicKey + expectedError string + }{ + { + desc: "wrong user", + user: "wrong-user", + key: rsaPublicKey(t), + expectedError: "unknown user", + }, { + desc: "prohibited dsa key", + user: "user", + key: dsaPublicKey(t), + expectedError: "DSA is prohibited", + }, { + desc: "API error", + user: "user", + key: rsaPublicKey(t), + expectedError: "Internal API unreachable", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + _, err = cfg.getAuthKey(context.Background(), tc.user, tc.key) + require.Error(t, err) + require.Equal(t, tc.expectedError, err.Error()) + }) + } +} + +func rsaPublicKey(t *testing.T) ssh.PublicKey { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + + return publicKey +} + +func dsaPublicKey(t *testing.T) ssh.PublicKey { + privateKey := new(dsa.PrivateKey) + params := new(dsa.Parameters) + require.NoError(t, dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160)) + + privateKey.PublicKey.Parameters = *params + require.NoError(t, dsa.GenerateKey(privateKey, rand.Reader)) + + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + + return publicKey +} diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index de5fbd4..ff9e765 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -2,13 +2,9 @@ package sshd import ( "context" - "encoding/base64" - "errors" "fmt" "net" "net/http" - "os" - "strconv" "sync" "time" @@ -16,7 +12,6 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/internal/config" - "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" @@ -35,40 +30,20 @@ const ( type Server struct { Config *config.Config - status status - statusMu sync.Mutex - wg sync.WaitGroup - listener net.Listener - hostKeys []ssh.Signer - authorizedKeysClient *authorizedkeys.Client + status status + statusMu sync.Mutex + wg sync.WaitGroup + listener net.Listener + serverConfig *serverConfig } func NewServer(cfg *config.Config) (*Server, error) { - authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + serverConfig, err := newServerConfig(cfg) if err != nil { - return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + return nil, err } - var hostKeys []ssh.Signer - for _, filename := range cfg.Server.HostKeyFiles { - keyRaw, err := os.ReadFile(filename) - if err != nil { - log.WithError(err).Warnf("Failed to read host key %v", filename) - continue - } - key, err := ssh.ParsePrivateKey(keyRaw) - if err != nil { - log.WithError(err).Warnf("Failed to parse host key %v", filename) - continue - } - - hostKeys = append(hostKeys, key) - } - if len(hostKeys) == 0 { - return nil, fmt.Errorf("No host keys could be loaded, aborting") - } - - return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil + return &Server{Config: cfg, serverConfig: serverConfig}, nil } func (s *Server) ListenAndServe(ctx context.Context) error { @@ -168,38 +143,6 @@ func (s *Server) getStatus() status { return s.status } -func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig { - sshCfg := &ssh.ServerConfig{ - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if conn.User() != s.Config.User { - return nil, errors.New("unknown user") - } - if key.Type() == ssh.KeyAlgoDSA { - return nil, errors.New("DSA is prohibited") - } - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) - if err != nil { - return nil, err - } - - return &ssh.Permissions{ - // Record the public key used for authentication. - Extensions: map[string]string{ - "key-id": strconv.FormatInt(res.Id, 10), - }, - }, nil - }, - } - - for _, key := range s.hostKeys { - sshCfg.AddHostKey(key) - } - - return sshCfg -} - func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { remoteAddr := nconn.RemoteAddr().String() @@ -216,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) defer cancel() - sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx)) + sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx)) if err != nil { log.WithError(err).Info("Failed to initialize SSH connection") return diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index cba1c3f..71f7733 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -104,13 +104,6 @@ func TestLivenessProbe(t *testing.T) { require.Equal(t, 200, r.Result().StatusCode) } -func TestNewServerWithoutHosts(t *testing.T) { - _, err := NewServer(&config.Config{GitlabUrl: "http://localhost"}) - - require.Error(t, err) - require.Equal(t, "No host keys could be loaded, aborting", err.Error()) -} - func TestInvalidClientConfig(t *testing.T) { setupServer(t) @@ -120,6 +113,15 @@ func TestInvalidClientConfig(t *testing.T) { require.Error(t, err) } +func TestInvalidServerConfig(t *testing.T) { + s := &Server{Config: &config.Config{Server: config.ServerConfig{Listen: "invalid"}}} + err := s.ListenAndServe(context.Background()) + + require.Error(t, err) + require.Equal(t, "failed to listen for connection: listen tcp: address invalid: missing port in address", err.Error()) + require.Nil(t, s.Shutdown()) +} + func setupServer(t *testing.T) *Server { t.Helper() |