summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
authorMikhail Mazurskiy <mmazurskiy@gitlab.com>2021-09-09 21:05:25 +1000
committerMikhail Mazurskiy <mmazurskiy@gitlab.com>2021-09-09 21:05:25 +1000
commit44737afce375d68a3c122991f6d94e3a84233dbb (patch)
tree7d5b299ce6bcddcb0bdf33fa24a1c42c5e65e882 /client
parent5edb579c23a06a2795c199478c88782b25f34d0d (diff)
downloadgitlab-shell-ash2k/use-moved-gitlab-client.tar.gz
Use moved GitLab client from Gitalyash2k/use-moved-gitlab-client
See https://gitlab.com/gitlab-org/gitaly/-/merge_requests/3850
Diffstat (limited to 'client')
-rw-r--r--client/client_test.go238
-rw-r--r--client/gitlabnet.go161
-rw-r--r--client/httpclient.go189
-rw-r--r--client/httpclient_test.go133
-rw-r--r--client/httpsclient_test.go133
-rw-r--r--client/testserver/gitalyserver.go5
-rw-r--r--client/testserver/testserver.go37
7 files changed, 4 insertions, 892 deletions
diff --git a/client/client_test.go b/client/client_test.go
deleted file mode 100644
index 48681f7..0000000
--- a/client/client_test.go
+++ /dev/null
@@ -1,238 +0,0 @@
-package client
-
-import (
- "context"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "path"
- "strings"
- "testing"
-
- "github.com/stretchr/testify/require"
-
- "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
- "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
-)
-
-func TestClients(t *testing.T) {
- testhelper.PrepareTestRootDir(t)
-
- testCases := []struct {
- desc string
- relativeURLRoot string
- caFile string
- server func(*testing.T, []testserver.TestRequestHandler) string
- }{
- {
- desc: "Socket client",
- server: testserver.StartSocketHttpServer,
- },
- {
- desc: "Socket client with a relative URL at /",
- relativeURLRoot: "/",
- server: testserver.StartSocketHttpServer,
- },
- {
- desc: "Socket client with relative URL at /gitlab",
- relativeURLRoot: "/gitlab",
- server: testserver.StartSocketHttpServer,
- },
- {
- desc: "Http client",
- server: testserver.StartHttpServer,
- },
- {
- desc: "Https client",
- caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
- server: func(t *testing.T, handlers []testserver.TestRequestHandler) string {
- return testserver.StartHttpsServer(t, handlers, "")
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- url := tc.server(t, buildRequests(t, tc.relativeURLRoot))
-
- secret := "sssh, it's a secret"
-
- httpClient, err := NewHTTPClientWithOpts(url, tc.relativeURLRoot, tc.caFile, "", false, 1, nil)
- require.NoError(t, err)
-
- client, err := NewGitlabNetClient("", "", secret, httpClient)
- require.NoError(t, err)
-
- testBrokenRequest(t, client)
- testSuccessfulGet(t, client)
- testSuccessfulPost(t, client)
- testMissing(t, client)
- testErrorMessage(t, client)
- testAuthenticationHeader(t, client)
- })
- }
-}
-
-func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
- t.Run("Successful get", func(t *testing.T) {
- response, err := client.Get(context.Background(), "/hello")
- require.NoError(t, err)
- require.NotNil(t, response)
-
- defer response.Body.Close()
-
- responseBody, err := io.ReadAll(response.Body)
- require.NoError(t, err)
- require.Equal(t, string(responseBody), "Hello")
- })
-}
-
-func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
- t.Run("Successful Post", func(t *testing.T) {
- data := map[string]string{"key": "value"}
-
- response, err := client.Post(context.Background(), "/post_endpoint", data)
- require.NoError(t, err)
- require.NotNil(t, response)
-
- defer response.Body.Close()
-
- responseBody, err := io.ReadAll(response.Body)
- require.NoError(t, err)
- require.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody))
- })
-}
-
-func testMissing(t *testing.T, client *GitlabNetClient) {
- t.Run("Missing error for GET", func(t *testing.T) {
- response, err := client.Get(context.Background(), "/missing")
- require.EqualError(t, err, "Internal API error (404)")
- require.Nil(t, response)
- })
-
- t.Run("Missing error for POST", func(t *testing.T) {
- response, err := client.Post(context.Background(), "/missing", map[string]string{})
- require.EqualError(t, err, "Internal API error (404)")
- require.Nil(t, response)
- })
-}
-
-func testErrorMessage(t *testing.T, client *GitlabNetClient) {
- t.Run("Error with message for GET", func(t *testing.T) {
- response, err := client.Get(context.Background(), "/error")
- require.EqualError(t, err, "Don't do that")
- require.Nil(t, response)
- })
-
- t.Run("Error with message for POST", func(t *testing.T) {
- response, err := client.Post(context.Background(), "/error", map[string]string{})
- require.EqualError(t, err, "Don't do that")
- require.Nil(t, response)
- })
-}
-
-func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
- t.Run("Broken request for GET", func(t *testing.T) {
- response, err := client.Get(context.Background(), "/broken")
- require.EqualError(t, err, "Internal API unreachable")
- require.Nil(t, response)
- })
-
- t.Run("Broken request for POST", func(t *testing.T) {
- response, err := client.Post(context.Background(), "/broken", map[string]string{})
- require.EqualError(t, err, "Internal API unreachable")
- require.Nil(t, response)
- })
-}
-
-func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
- t.Run("Authentication headers for GET", func(t *testing.T) {
- response, err := client.Get(context.Background(), "/auth")
- require.NoError(t, err)
- require.NotNil(t, response)
-
- defer response.Body.Close()
-
- responseBody, err := io.ReadAll(response.Body)
- require.NoError(t, err)
-
- header, err := base64.StdEncoding.DecodeString(string(responseBody))
- require.NoError(t, err)
- require.Equal(t, "sssh, it's a secret", string(header))
- })
-
- t.Run("Authentication headers for POST", func(t *testing.T) {
- response, err := client.Post(context.Background(), "/auth", map[string]string{})
- require.NoError(t, err)
- require.NotNil(t, response)
-
- defer response.Body.Close()
-
- responseBody, err := io.ReadAll(response.Body)
- require.NoError(t, err)
-
- header, err := base64.StdEncoding.DecodeString(string(responseBody))
- require.NoError(t, err)
- require.Equal(t, "sssh, it's a secret", string(header))
- })
-}
-
-func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestRequestHandler {
- requests := []testserver.TestRequestHandler{
- {
- Path: "/api/v4/internal/hello",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodGet, r.Method)
-
- fmt.Fprint(w, "Hello")
- },
- },
- {
- Path: "/api/v4/internal/post_endpoint",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodPost, r.Method)
-
- b, err := io.ReadAll(r.Body)
- defer r.Body.Close()
-
- require.NoError(t, err)
-
- fmt.Fprint(w, "Echo: "+string(b))
- },
- },
- {
- Path: "/api/v4/internal/auth",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprint(w, r.Header.Get(secretHeaderName))
- },
- },
- {
- Path: "/api/v4/internal/error",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusBadRequest)
- body := map[string]string{
- "message": "Don't do that",
- }
- json.NewEncoder(w).Encode(body)
- },
- },
- {
- Path: "/api/v4/internal/broken",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- panic("Broken")
- },
- },
- }
-
- relativeURLRoot = strings.Trim(relativeURLRoot, "/")
- if relativeURLRoot != "" {
- for i, r := range requests {
- requests[i].Path = fmt.Sprintf("/%s%s", relativeURLRoot, r.Path)
- }
- }
-
- return requests
-}
diff --git a/client/gitlabnet.go b/client/gitlabnet.go
deleted file mode 100644
index f71c110..0000000
--- a/client/gitlabnet.go
+++ /dev/null
@@ -1,161 +0,0 @@
-package client
-
-import (
- "bytes"
- "context"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "time"
-
- "gitlab.com/gitlab-org/labkit/log"
-)
-
-const (
- internalApiPath = "/api/v4/internal"
- secretHeaderName = "Gitlab-Shared-Secret"
- defaultUserAgent = "GitLab-Shell"
-)
-
-type ErrorResponse struct {
- Message string `json:"message"`
-}
-
-type GitlabNetClient struct {
- httpClient *HttpClient
- user string
- password string
- secret string
- userAgent string
-}
-
-func NewGitlabNetClient(
- user,
- password,
- secret string,
- httpClient *HttpClient,
-) (*GitlabNetClient, error) {
-
- if httpClient == nil {
- return nil, fmt.Errorf("Unsupported protocol")
- }
-
- return &GitlabNetClient{
- httpClient: httpClient,
- user: user,
- password: password,
- secret: secret,
- userAgent: defaultUserAgent,
- }, nil
-}
-
-// SetUserAgent overrides the default user agent for the User-Agent header field
-// for subsequent requests for the GitlabNetClient
-func (c *GitlabNetClient) SetUserAgent(ua string) {
- c.userAgent = ua
-}
-
-func normalizePath(path string) string {
- if !strings.HasPrefix(path, "/") {
- path = "/" + path
- }
-
- if !strings.HasPrefix(path, internalApiPath) {
- path = internalApiPath + path
- }
- return path
-}
-
-func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, error) {
- var jsonReader io.Reader
- if data != nil {
- jsonData, err := json.Marshal(data)
- if err != nil {
- return nil, err
- }
-
- jsonReader = bytes.NewReader(jsonData)
- }
-
- request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader)
- if err != nil {
- return nil, err
- }
-
- return request, nil
-}
-
-func parseError(resp *http.Response) error {
- if resp.StatusCode >= 200 && resp.StatusCode <= 399 {
- return nil
- }
- defer resp.Body.Close()
- parsedResponse := &ErrorResponse{}
-
- if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
- return fmt.Errorf("Internal API error (%v)", resp.StatusCode)
- } else {
- return fmt.Errorf(parsedResponse.Message)
- }
-
-}
-
-func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) {
- return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil)
-}
-
-func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) {
- return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data)
-}
-
-func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) {
- request, err := newRequest(ctx, method, c.httpClient.Host, path, data)
- if err != nil {
- return nil, err
- }
-
- user, password := c.user, c.password
- if user != "" && password != "" {
- request.SetBasicAuth(user, password)
- }
-
- encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.secret))
- request.Header.Set(secretHeaderName, encodedSecret)
-
- request.Header.Add("Content-Type", "application/json")
- request.Header.Add("User-Agent", c.userAgent)
- request.Close = true
-
- start := time.Now()
- response, err := c.httpClient.Do(request)
- fields := log.Fields{
- "method": method,
- "url": request.URL.String(),
- "duration_ms": time.Since(start) / time.Millisecond,
- }
- logger := log.WithContextFields(ctx, fields)
-
- if err != nil {
- logger.WithError(err).Error("Internal API unreachable")
- return nil, fmt.Errorf("Internal API unreachable")
- }
-
- if response != nil {
- logger = logger.WithField("status", response.StatusCode)
- }
- if err := parseError(response); err != nil {
- logger.WithError(err).Error("Internal API error")
- return nil, err
- }
-
- if response.ContentLength >= 0 {
- logger = logger.WithField("content_length_bytes", response.ContentLength)
- }
-
- logger.Info("Finished HTTP request")
-
- return response, nil
-}
diff --git a/client/httpclient.go b/client/httpclient.go
deleted file mode 100644
index 72238f8..0000000
--- a/client/httpclient.go
+++ /dev/null
@@ -1,189 +0,0 @@
-package client
-
-import (
- "context"
- "crypto/tls"
- "crypto/x509"
- "errors"
- "fmt"
- "net"
- "net/http"
- "os"
- "path/filepath"
- "strings"
- "time"
-
- "gitlab.com/gitlab-org/labkit/correlation"
- "gitlab.com/gitlab-org/labkit/log"
- "gitlab.com/gitlab-org/labkit/tracing"
-)
-
-const (
- socketBaseUrl = "http://unix"
- unixSocketProtocol = "http+unix://"
- httpProtocol = "http://"
- httpsProtocol = "https://"
- defaultReadTimeoutSeconds = 300
-)
-
-var (
- ErrCafileNotFound = errors.New("cafile not found")
-)
-
-type HttpClient struct {
- *http.Client
- Host string
-}
-
-type httpClientCfg struct {
- keyPath, certPath string
- caFile, caPath string
-}
-
-func (hcc httpClientCfg) HaveCertAndKey() bool { return hcc.keyPath != "" && hcc.certPath != "" }
-
-// HTTPClientOpt provides options for configuring an HttpClient
-type HTTPClientOpt func(*httpClientCfg)
-
-// WithClientCert will configure the HttpClient to provide client certificates
-// when connecting to a server.
-func WithClientCert(certPath, keyPath string) HTTPClientOpt {
- return func(hcc *httpClientCfg) {
- hcc.keyPath = keyPath
- hcc.certPath = certPath
- }
-}
-
-// Deprecated: use NewHTTPClientWithOpts - https://gitlab.com/gitlab-org/gitlab-shell/-/issues/484
-func NewHTTPClient(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64) *HttpClient {
- c, err := NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath, selfSignedCert, readTimeoutSeconds, nil)
- if err != nil {
- log.WithError(err).Error("new http client with opts")
- }
- return c
-}
-
-// NewHTTPClientWithOpts builds an HTTP client using the provided options
-func NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64, opts []HTTPClientOpt) (*HttpClient, error) {
- var transport *http.Transport
- var host string
- var err error
- if strings.HasPrefix(gitlabURL, unixSocketProtocol) {
- transport, host = buildSocketTransport(gitlabURL, gitlabRelativeURLRoot)
- } else if strings.HasPrefix(gitlabURL, httpProtocol) {
- transport, host = buildHttpTransport(gitlabURL)
- } else if strings.HasPrefix(gitlabURL, httpsProtocol) {
- if _, err := os.Stat(caFile); err != nil {
- if os.IsNotExist(err) {
- return nil, fmt.Errorf("cannot find cafile '%s': %w", caFile, ErrCafileNotFound)
- }
- return nil, err
- }
-
- hcc := &httpClientCfg{
- caFile: caFile,
- caPath: caPath,
- }
-
- for _, opt := range opts {
- opt(hcc)
- }
-
- transport, host, err = buildHttpsTransport(*hcc, selfSignedCert, gitlabURL)
- if err != nil {
- return nil, err
- }
- } else {
- return nil, errors.New("unknown GitLab URL prefix")
- }
-
- c := &http.Client{
- Transport: correlation.NewInstrumentedRoundTripper(tracing.NewRoundTripper(transport)),
- Timeout: readTimeout(readTimeoutSeconds),
- }
-
- client := &HttpClient{Client: c, Host: host}
-
- return client, nil
-}
-
-func buildSocketTransport(gitlabURL, gitlabRelativeURLRoot string) (*http.Transport, string) {
- socketPath := strings.TrimPrefix(gitlabURL, unixSocketProtocol)
-
- transport := &http.Transport{
- DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
- dialer := net.Dialer{}
- return dialer.DialContext(ctx, "unix", socketPath)
- },
- }
-
- host := socketBaseUrl
- gitlabRelativeURLRoot = strings.Trim(gitlabRelativeURLRoot, "/")
- if gitlabRelativeURLRoot != "" {
- host = host + "/" + gitlabRelativeURLRoot
- }
-
- return transport, host
-}
-
-func buildHttpsTransport(hcc httpClientCfg, selfSignedCert bool, gitlabURL string) (*http.Transport, string, error) {
- certPool, err := x509.SystemCertPool()
-
- if err != nil {
- certPool = x509.NewCertPool()
- }
-
- if hcc.caFile != "" {
- addCertToPool(certPool, hcc.caFile)
- }
-
- if hcc.caPath != "" {
- fis, _ := os.ReadDir(hcc.caPath)
- for _, fi := range fis {
- if fi.IsDir() {
- continue
- }
-
- addCertToPool(certPool, filepath.Join(hcc.caPath, fi.Name()))
- }
- }
- tlsConfig := &tls.Config{
- RootCAs: certPool,
- InsecureSkipVerify: selfSignedCert,
- MinVersion: tls.VersionTLS12,
- }
-
- if hcc.HaveCertAndKey() {
- cert, err := tls.LoadX509KeyPair(hcc.certPath, hcc.keyPath)
- if err != nil {
- return nil, "", err
- }
- tlsConfig.Certificates = []tls.Certificate{cert}
- tlsConfig.BuildNameToCertificate()
- }
-
- transport := &http.Transport{
- TLSClientConfig: tlsConfig,
- }
-
- return transport, gitlabURL, err
-}
-
-func addCertToPool(certPool *x509.CertPool, fileName string) {
- cert, err := os.ReadFile(fileName)
- if err == nil {
- certPool.AppendCertsFromPEM(cert)
- }
-}
-
-func buildHttpTransport(gitlabURL string) (*http.Transport, string) {
- return &http.Transport{}, gitlabURL
-}
-
-func readTimeout(timeoutSeconds uint64) time.Duration {
- if timeoutSeconds == 0 {
- timeoutSeconds = defaultReadTimeoutSeconds
- }
-
- return time.Duration(timeoutSeconds) * time.Second
-}
diff --git a/client/httpclient_test.go b/client/httpclient_test.go
deleted file mode 100644
index f7a6340..0000000
--- a/client/httpclient_test.go
+++ /dev/null
@@ -1,133 +0,0 @@
-package client
-
-import (
- "context"
- "encoding/base64"
- "fmt"
- "io"
- "net/http"
- "strings"
- "testing"
- "time"
-
- "github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
-)
-
-func TestReadTimeout(t *testing.T) {
- expectedSeconds := uint64(300)
-
- client, err := NewHTTPClientWithOpts("http://localhost:3000", "", "", "", false, expectedSeconds, nil)
- require.NoError(t, err)
-
- require.NotNil(t, client)
- require.Equal(t, time.Duration(expectedSeconds)*time.Second, client.Client.Timeout)
-}
-
-const (
- username = "basic_auth_user"
- password = "basic_auth_password"
-)
-
-func TestBasicAuthSettings(t *testing.T) {
- requests := []testserver.TestRequestHandler{
- {
- Path: "/api/v4/internal/get_endpoint",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodGet, r.Method)
-
- fmt.Fprint(w, r.Header.Get("Authorization"))
- },
- },
- {
- Path: "/api/v4/internal/post_endpoint",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodPost, r.Method)
-
- fmt.Fprint(w, r.Header.Get("Authorization"))
- },
- },
- }
-
- client := setup(t, username, password, requests)
-
- response, err := client.Get(context.Background(), "/get_endpoint")
- require.NoError(t, err)
- testBasicAuthHeaders(t, response)
-
- response, err = client.Post(context.Background(), "/post_endpoint", nil)
- require.NoError(t, err)
- testBasicAuthHeaders(t, response)
-}
-
-func testBasicAuthHeaders(t *testing.T, response *http.Response) {
- defer response.Body.Close()
-
- require.NotNil(t, response)
- responseBody, err := io.ReadAll(response.Body)
- require.NoError(t, err)
-
- headerParts := strings.Split(string(responseBody), " ")
- require.Equal(t, "Basic", headerParts[0])
-
- credentials, err := base64.StdEncoding.DecodeString(headerParts[1])
- require.NoError(t, err)
-
- require.Equal(t, username+":"+password, string(credentials))
-}
-
-func TestEmptyBasicAuthSettings(t *testing.T) {
- requests := []testserver.TestRequestHandler{
- {
- Path: "/api/v4/internal/empty_basic_auth",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, "", r.Header.Get("Authorization"))
- },
- },
- }
-
- client := setup(t, "", "", requests)
-
- _, err := client.Get(context.Background(), "/empty_basic_auth")
- require.NoError(t, err)
-}
-
-func TestRequestWithUserAgent(t *testing.T) {
- const gitalyUserAgent = "gitaly/13.5.0"
- requests := []testserver.TestRequestHandler{
- {
- Path: "/api/v4/internal/default_user_agent",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, defaultUserAgent, r.UserAgent())
- },
- },
- {
- Path: "/api/v4/internal/override_user_agent",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, gitalyUserAgent, r.UserAgent())
- },
- },
- }
-
- client := setup(t, "", "", requests)
-
- _, err := client.Get(context.Background(), "/default_user_agent")
- require.NoError(t, err)
-
- client.SetUserAgent(gitalyUserAgent)
- _, err = client.Get(context.Background(), "/override_user_agent")
- require.NoError(t, err)
-
-}
-
-func setup(t *testing.T, username, password string, requests []testserver.TestRequestHandler) *GitlabNetClient {
- url := testserver.StartHttpServer(t, requests)
-
- httpClient, err := NewHTTPClientWithOpts(url, "", "", "", false, 1, nil)
- require.NoError(t, err)
-
- client, err := NewGitlabNetClient(username, password, "", httpClient)
- require.NoError(t, err)
-
- return client
-}
diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go
deleted file mode 100644
index d2c2293..0000000
--- a/client/httpsclient_test.go
+++ /dev/null
@@ -1,133 +0,0 @@
-package client
-
-import (
- "context"
- "fmt"
- "io"
- "net/http"
- "path"
- "testing"
-
- "github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
- "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
-)
-
-//go:generate openssl req -newkey rsa:4096 -new -nodes -x509 -days 3650 -out ../internal/testhelper/testdata/testroot/certs/client/server.crt -keyout ../internal/testhelper/testdata/testroot/certs/client/key.pem -subj "/C=US/ST=California/L=San Francisco/O=GitLab/OU=GitLab-Shell/CN=localhost"
-func TestSuccessfulRequests(t *testing.T) {
- testCases := []struct {
- desc string
- caFile, caPath string
- selfSigned bool
- clientCAPath, clientCertPath, clientKeyPath string // used for TLS client certs
- }{
- {
- desc: "Valid CaFile",
- caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
- },
- {
- desc: "Valid CaPath",
- caPath: path.Join(testhelper.TestRoot, "certs/valid"),
- caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
- },
- {
- desc: "Invalid cert with self signed cert option enabled",
- caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
- selfSigned: true,
- },
- {
- desc: "Client certs with CA",
- caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
- // Run the command "go generate httpsclient_test.go" to
- // regenerate the following test fixtures:
- clientCAPath: path.Join(testhelper.TestRoot, "certs/client/server.crt"),
- clientCertPath: path.Join(testhelper.TestRoot, "certs/client/server.crt"),
- clientKeyPath: path.Join(testhelper.TestRoot, "certs/client/key.pem"),
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- client, err := setupWithRequests(t, tc.caFile, tc.caPath, tc.clientCAPath, tc.clientCertPath, tc.clientKeyPath, tc.selfSigned)
- require.NoError(t, err)
-
- response, err := client.Get(context.Background(), "/hello")
- require.NoError(t, err)
- require.NotNil(t, response)
-
- defer response.Body.Close()
-
- responseBody, err := io.ReadAll(response.Body)
- require.NoError(t, err)
- require.Equal(t, string(responseBody), "Hello")
- })
- }
-}
-
-func TestFailedRequests(t *testing.T) {
- testCases := []struct {
- desc string
- caFile string
- caPath string
- expectedError string
- }{
- {
- desc: "Invalid CaFile",
- caFile: path.Join(testhelper.TestRoot, "certs/invalid/server.crt"),
- expectedError: "Internal API unreachable",
- },
- {
- desc: "Invalid CaPath",
- caPath: path.Join(testhelper.TestRoot, "certs/invalid"),
- },
- {
- desc: "Empty config",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- client, err := setupWithRequests(t, tc.caFile, tc.caPath, "", "", "", false)
- if tc.caFile == "" {
- require.Error(t, err)
- require.ErrorIs(t, err, ErrCafileNotFound)
- } else {
- _, err = client.Get(context.Background(), "/hello")
- require.Error(t, err)
-
- require.Equal(t, err.Error(), tc.expectedError)
- }
- })
- }
-}
-
-func setupWithRequests(t *testing.T, caFile, caPath, clientCAPath, clientCertPath, clientKeyPath string, selfSigned bool) (*GitlabNetClient, error) {
- testhelper.PrepareTestRootDir(t)
-
- requests := []testserver.TestRequestHandler{
- {
- Path: "/api/v4/internal/hello",
- Handler: func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodGet, r.Method)
-
- fmt.Fprint(w, "Hello")
- },
- },
- }
-
- url := testserver.StartHttpsServer(t, requests, clientCAPath)
-
- var opts []HTTPClientOpt
- if clientCertPath != "" && clientKeyPath != "" {
- opts = append(opts, WithClientCert(clientCertPath, clientKeyPath))
- }
-
- httpClient, err := NewHTTPClientWithOpts(url, "", caFile, caPath, selfSigned, 1, opts)
- if err != nil {
- return nil, err
- }
-
- client, err := NewGitlabNetClient("", "", "", httpClient)
-
- return client, err
-}
diff --git a/client/testserver/gitalyserver.go b/client/testserver/gitalyserver.go
index 7159d16..92fbe86 100644
--- a/client/testserver/gitalyserver.go
+++ b/client/testserver/gitalyserver.go
@@ -13,7 +13,10 @@ import (
"google.golang.org/grpc/metadata"
)
-type TestGitalyServer struct{ ReceivedMD metadata.MD }
+type TestGitalyServer struct {
+ pb.UnimplementedSSHServiceServer
+ ReceivedMD metadata.MD
+}
func (s *TestGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error {
req, err := stream.Recv()
diff --git a/client/testserver/testserver.go b/client/testserver/testserver.go
index c263aa0..9a1509c 100644
--- a/client/testserver/testserver.go
+++ b/client/testserver/testserver.go
@@ -1,8 +1,6 @@
package testserver
import (
- "crypto/tls"
- "crypto/x509"
"io"
"log"
"net"
@@ -14,7 +12,6 @@ import (
"testing"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
)
var (
@@ -59,40 +56,6 @@ func StartHttpServer(t *testing.T, handlers []TestRequestHandler) string {
return server.URL
}
-func StartHttpsServer(t *testing.T, handlers []TestRequestHandler, clientCAPath string) string {
- t.Helper()
-
- crt := path.Join(testhelper.TestRoot, "certs/valid/server.crt")
- key := path.Join(testhelper.TestRoot, "certs/valid/server.key")
-
- server := httptest.NewUnstartedServer(buildHandler(handlers))
- cer, err := tls.LoadX509KeyPair(crt, key)
- require.NoError(t, err)
-
- server.TLS = &tls.Config{
- Certificates: []tls.Certificate{cer},
- MinVersion: tls.VersionTLS12,
- }
- server.TLS.BuildNameToCertificate()
-
- if clientCAPath != "" {
- caCert, err := os.ReadFile(clientCAPath)
- require.NoError(t, err)
-
- caCertPool := x509.NewCertPool()
- caCertPool.AppendCertsFromPEM(caCert)
-
- server.TLS.ClientCAs = caCertPool
- server.TLS.ClientAuth = tls.RequireAndVerifyClientCert
- }
-
- server.StartTLS()
-
- t.Cleanup(func() { server.Close() })
-
- return server.URL
-}
-
func buildHandler(handlers []TestRequestHandler) http.Handler {
h := http.NewServeMux()