diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2021-07-27 14:52:45 +0300 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2021-07-27 17:46:11 +0300 |
commit | f6baecaa794ef85b144fa9cd05940e3f020b4a0e (patch) | |
tree | 034e4ae73aa5522a73db0506e67f567648bc507f | |
parent | f9e7ffda68192d24ff26f0d5ff7fe70e376c32f2 (diff) | |
download | gitlab-shell-f6baecaa794ef85b144fa9cd05940e3f020b4a0e.tar.gz |
Sshd: Log same correlation_id on auth keys
-rw-r--r-- | cmd/gitlab-sshd/main.go | 7 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 85 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 120 | ||||
-rw-r--r-- | internal/testhelper/testdata/testroot/certs/valid/server_authorized_key | 1 |
4 files changed, 157 insertions, 56 deletions
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go index d1cc84e..78690b0 100644 --- a/cmd/gitlab-sshd/main.go +++ b/cmd/gitlab-sshd/main.go @@ -68,7 +68,10 @@ func main() { ctx, finished := command.Setup("gitlab-sshd", cfg) defer finished() - server := sshd.Server{Config: cfg} + server, err := sshd.NewServer(cfg) + if err != nil { + log.WithError(err).Fatal("Failed to start GitLab built-in sshd") + } // Startup monitoring endpoint. if cfg.Server.WebListen != "" { @@ -104,6 +107,6 @@ func main() { }() if err := server.ListenAndServe(ctx); err != nil { - log.WithError(err).Fatal("Failed to start GitLab built-in sshd") + log.WithError(err).Fatal("GitLab built-in sshd failed to listen for new connections") } } diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 8b49712..b918109 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -35,10 +35,40 @@ const ( type Server struct { Config *config.Config - status status - statusMu sync.Mutex - wg sync.WaitGroup - listener net.Listener + status status + statusMu sync.Mutex + wg sync.WaitGroup + listener net.Listener + hostKeys []ssh.Signer + authorizedKeysClient *authorizedkeys.Client +} + +func NewServer(cfg *config.Config) (*Server, error) { + authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + } + + var hostKeys []ssh.Signer + for _, filename := range cfg.Server.HostKeyFiles { + keyRaw, err := ioutil.ReadFile(filename) + if err != nil { + log.WithError(err).Warnf("Failed to read host key %v", filename) + continue + } + key, err := ssh.ParsePrivateKey(keyRaw) + if err != nil { + log.WithError(err).Warnf("Failed to parse host key %v", filename) + continue + } + + hostKeys = append(hostKeys, key) + } + if len(hostKeys) == 0 { + return nil, fmt.Errorf("No host keys could be loaded, aborting") + } + + return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil } func (s *Server) ListenAndServe(ctx context.Context) error { @@ -47,7 +77,9 @@ func (s *Server) ListenAndServe(ctx context.Context) error { } defer s.listener.Close() - return s.serve(ctx) + s.serve(ctx) + + return nil } func (s *Server) Shutdown() error { @@ -100,12 +132,7 @@ func (s *Server) listen() error { return nil } -func (s *Server) serve(ctx context.Context) error { - sshCfg, err := s.initConfig(ctx) - if err != nil { - return err - } - +func (s *Server) serve(ctx context.Context) { s.changeStatus(StatusReady) for { @@ -120,14 +147,12 @@ func (s *Server) serve(ctx context.Context) error { } s.wg.Add(1) - go s.handleConn(ctx, sshCfg, nconn) + go s.handleConn(ctx, nconn) } s.wg.Wait() s.changeStatus(StatusClosed) - - return nil } func (s *Server) changeStatus(st status) { @@ -143,12 +168,7 @@ func (s *Server) getStatus() status { return s.status } -func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { - authorizedKeysClient, err := authorizedkeys.NewClient(s.Config) - if err != nil { - return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) - } - +func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig { sshCfg := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if conn.User() != s.Config.User { @@ -159,7 +179,7 @@ func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { } ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - res, err := authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) + res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) if err != nil { return nil, err } @@ -173,29 +193,14 @@ func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { }, } - var loadedHostKeys uint - for _, filename := range s.Config.Server.HostKeyFiles { - keyRaw, err := ioutil.ReadFile(filename) - if err != nil { - log.WithError(err).Warnf("Failed to read host key %v", filename) - continue - } - key, err := ssh.ParsePrivateKey(keyRaw) - if err != nil { - log.WithError(err).Warnf("Failed to parse host key %v", filename) - continue - } - loadedHostKeys++ + for _, key := range s.hostKeys { sshCfg.AddHostKey(key) } - if loadedHostKeys == 0 { - return nil, fmt.Errorf("No host keys could be loaded, aborting") - } - return sshCfg, nil + return sshCfg } -func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) { +func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { remoteAddr := nconn.RemoteAddr().String() defer s.wg.Done() @@ -211,7 +216,7 @@ func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) defer cancel() - sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) + sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx)) if err != nil { log.WithError(err).Info("Failed to initialize SSH connection") return diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index e5f6111..2923737 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -2,37 +2,71 @@ package sshd import ( "context" + "fmt" + "io/ioutil" + "net/http" "net/http/httptest" "path" "testing" "time" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/client/testserver" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" ) -const serverUrl = "127.0.0.1:50000" - -func TestShutdown(t *testing.T) { - s := setupServer(t) +const ( + serverUrl = "127.0.0.1:50000" + user = "git" +) - go func() { require.NoError(t, s.ListenAndServe(context.Background())) }() +var ( + correlationId = "" +) - verifyStatus(t, s, StatusReady) +func TestListenAndServe(t *testing.T) { + s := setupServer(t) - s.wg.Add(1) + client, err := ssh.Dial("tcp", serverUrl, clientConfig(t)) + require.NoError(t, err) + defer client.Close() require.NoError(t, s.Shutdown()) verifyStatus(t, s, StatusOnShutdown) - s.wg.Done() + holdSession(t, client) + + _, err = ssh.Dial("tcp", serverUrl, clientConfig(t)) + require.Equal(t, err.Error(), "dial tcp 127.0.0.1:50000: connect: connection refused") + + client.Close() verifyStatus(t, s, StatusClosed) } +func TestCorrelationId(t *testing.T) { + setupServer(t) + + client, err := ssh.Dial("tcp", serverUrl, clientConfig(t)) + require.NoError(t, err) + defer client.Close() + + holdSession(t, client) + + previousCorrelationId := correlationId + + client, err = ssh.Dial("tcp", serverUrl, clientConfig(t)) + require.NoError(t, err) + defer client.Close() + + holdSession(t, client) + + require.NotEqual(t, previousCorrelationId, correlationId) +} + func TestReadinessProbe(t *testing.T) { s := &Server{Config: &config.Config{Server: config.DefaultServerConfig}} @@ -71,17 +105,75 @@ func TestLivenessProbe(t *testing.T) { } func setupServer(t *testing.T) *Server { + t.Helper() + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/authorized_keys", + Handler: func(w http.ResponseWriter, r *http.Request) { + correlationId = r.Header.Get("X-Request-Id") + + require.NotEmpty(t, correlationId) + + fmt.Fprint(w, `{"id": 1000, "key": "key"}`) + }, + }, { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, correlationId, r.Header.Get("X-Request-Id")) + + fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`) + }, + }, + } + testhelper.PrepareTestRootDir(t) - url := testserver.StartSocketHttpServer(t, []testserver.TestRequestHandler{}) + url := testserver.StartSocketHttpServer(t, requests) srvCfg := config.ServerConfig{ - Listen: serverUrl, - HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}, + Listen: serverUrl, + ConcurrentSessionsLimit: 1, + HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}, + } + + s, err := NewServer(&config.Config{User: user, RootDir: "/tmp", GitlabUrl: url, Server: srvCfg}) + require.NoError(t, err) + + go func() { require.NoError(t, s.ListenAndServe(context.Background())) }() + t.Cleanup(func() { s.Shutdown() }) + + verifyStatus(t, s, StatusReady) + + return s +} + +func clientConfig(t *testing.T) *ssh.ClientConfig { + keyRaw, err := ioutil.ReadFile(path.Join(testhelper.TestRoot, "certs/valid/server_authorized_key")) + pKey, _, _, _, err := ssh.ParseAuthorizedKey(keyRaw) + require.NoError(t, err) + + key, err := ioutil.ReadFile(path.Join(testhelper.TestRoot, "certs/client/key.pem")) + require.NoError(t, err) + signer, err := ssh.ParsePrivateKey(key) + require.NoError(t, err) + + return &ssh.ClientConfig{ + User: user, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.FixedHostKey(pKey), } +} - cfg := &config.Config{RootDir: "/tmp", GitlabUrl: url, Server: srvCfg} +func holdSession(t *testing.T, c *ssh.Client) { + session, err := c.NewSession() + require.NoError(t, err) + defer session.Close() - return &Server{Config: cfg} + output, err := session.Output("discover") + require.NoError(t, err) + require.Equal(t, "Welcome to GitLab, @test-user!\n", string(output)) } func verifyStatus(t *testing.T, s *Server, st status) { @@ -94,5 +186,5 @@ func verifyStatus(t *testing.T, s *Server, st status) { time.Sleep(time.Duration(i) * time.Millisecond) } - require.Equal(t, s.getStatus(), st) + require.Equal(t, st, s.getStatus()) } diff --git a/internal/testhelper/testdata/testroot/certs/valid/server_authorized_key b/internal/testhelper/testdata/testroot/certs/valid/server_authorized_key new file mode 100644 index 0000000..784d80c --- /dev/null +++ b/internal/testhelper/testdata/testroot/certs/valid/server_authorized_key @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCa17cb94P6q5qbDIWX7aMSjyeBIBPQVZ5jlkDBG90XgWC1MEu9sB1OfKLukcx6wJJSTLFccc9rMzhINXq6K7ks0oXSLP81jvqsu0WipIZSDKBNkdVtno1FcI1RnQ+yUP3nA4Ja9L233GA1evLrqTz6Z9k2ET5wVB+s7+k3lak24bJZN8qVRDDk1UveahuPe1KMj7DNKls8y9tNCgGJn9UeTLJzXlh2tt4/AUHZ0lvET9eCzKT9PBZJQWcCzqLXHa37jbc0ib2sgNN1bZhgkle/cxRx0MjEmdjRt4Z48wjKaf1khFQm0r9lebAxvna/vT5hNywbru5KbfUJHyM23yql |