diff options
Diffstat (limited to 'client')
-rw-r--r-- | client/client_test.go | 238 | ||||
-rw-r--r-- | client/gitlabnet.go | 161 | ||||
-rw-r--r-- | client/httpclient.go | 189 | ||||
-rw-r--r-- | client/httpclient_test.go | 133 | ||||
-rw-r--r-- | client/httpsclient_test.go | 133 | ||||
-rw-r--r-- | client/testserver/gitalyserver.go | 5 | ||||
-rw-r--r-- | client/testserver/testserver.go | 37 |
7 files changed, 4 insertions, 892 deletions
diff --git a/client/client_test.go b/client/client_test.go deleted file mode 100644 index 48681f7..0000000 --- a/client/client_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package client - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "path" - "strings" - "testing" - - "github.com/stretchr/testify/require" - - "gitlab.com/gitlab-org/gitlab-shell/client/testserver" - "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" -) - -func TestClients(t *testing.T) { - testhelper.PrepareTestRootDir(t) - - testCases := []struct { - desc string - relativeURLRoot string - caFile string - server func(*testing.T, []testserver.TestRequestHandler) string - }{ - { - desc: "Socket client", - server: testserver.StartSocketHttpServer, - }, - { - desc: "Socket client with a relative URL at /", - relativeURLRoot: "/", - server: testserver.StartSocketHttpServer, - }, - { - desc: "Socket client with relative URL at /gitlab", - relativeURLRoot: "/gitlab", - server: testserver.StartSocketHttpServer, - }, - { - desc: "Http client", - server: testserver.StartHttpServer, - }, - { - desc: "Https client", - caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"), - server: func(t *testing.T, handlers []testserver.TestRequestHandler) string { - return testserver.StartHttpsServer(t, handlers, "") - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - url := tc.server(t, buildRequests(t, tc.relativeURLRoot)) - - secret := "sssh, it's a secret" - - httpClient, err := NewHTTPClientWithOpts(url, tc.relativeURLRoot, tc.caFile, "", false, 1, nil) - require.NoError(t, err) - - client, err := NewGitlabNetClient("", "", secret, httpClient) - require.NoError(t, err) - - testBrokenRequest(t, client) - testSuccessfulGet(t, client) - testSuccessfulPost(t, client) - testMissing(t, client) - testErrorMessage(t, client) - testAuthenticationHeader(t, client) - }) - } -} - -func testSuccessfulGet(t *testing.T, client *GitlabNetClient) { - t.Run("Successful get", func(t *testing.T) { - response, err := client.Get(context.Background(), "/hello") - require.NoError(t, err) - require.NotNil(t, response) - - defer response.Body.Close() - - responseBody, err := io.ReadAll(response.Body) - require.NoError(t, err) - require.Equal(t, string(responseBody), "Hello") - }) -} - -func testSuccessfulPost(t *testing.T, client *GitlabNetClient) { - t.Run("Successful Post", func(t *testing.T) { - data := map[string]string{"key": "value"} - - response, err := client.Post(context.Background(), "/post_endpoint", data) - require.NoError(t, err) - require.NotNil(t, response) - - defer response.Body.Close() - - responseBody, err := io.ReadAll(response.Body) - require.NoError(t, err) - require.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody)) - }) -} - -func testMissing(t *testing.T, client *GitlabNetClient) { - t.Run("Missing error for GET", func(t *testing.T) { - response, err := client.Get(context.Background(), "/missing") - require.EqualError(t, err, "Internal API error (404)") - require.Nil(t, response) - }) - - t.Run("Missing error for POST", func(t *testing.T) { - response, err := client.Post(context.Background(), "/missing", map[string]string{}) - require.EqualError(t, err, "Internal API error (404)") - require.Nil(t, response) - }) -} - -func testErrorMessage(t *testing.T, client *GitlabNetClient) { - t.Run("Error with message for GET", func(t *testing.T) { - response, err := client.Get(context.Background(), "/error") - require.EqualError(t, err, "Don't do that") - require.Nil(t, response) - }) - - t.Run("Error with message for POST", func(t *testing.T) { - response, err := client.Post(context.Background(), "/error", map[string]string{}) - require.EqualError(t, err, "Don't do that") - require.Nil(t, response) - }) -} - -func testBrokenRequest(t *testing.T, client *GitlabNetClient) { - t.Run("Broken request for GET", func(t *testing.T) { - response, err := client.Get(context.Background(), "/broken") - require.EqualError(t, err, "Internal API unreachable") - require.Nil(t, response) - }) - - t.Run("Broken request for POST", func(t *testing.T) { - response, err := client.Post(context.Background(), "/broken", map[string]string{}) - require.EqualError(t, err, "Internal API unreachable") - require.Nil(t, response) - }) -} - -func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) { - t.Run("Authentication headers for GET", func(t *testing.T) { - response, err := client.Get(context.Background(), "/auth") - require.NoError(t, err) - require.NotNil(t, response) - - defer response.Body.Close() - - responseBody, err := io.ReadAll(response.Body) - require.NoError(t, err) - - header, err := base64.StdEncoding.DecodeString(string(responseBody)) - require.NoError(t, err) - require.Equal(t, "sssh, it's a secret", string(header)) - }) - - t.Run("Authentication headers for POST", func(t *testing.T) { - response, err := client.Post(context.Background(), "/auth", map[string]string{}) - require.NoError(t, err) - require.NotNil(t, response) - - defer response.Body.Close() - - responseBody, err := io.ReadAll(response.Body) - require.NoError(t, err) - - header, err := base64.StdEncoding.DecodeString(string(responseBody)) - require.NoError(t, err) - require.Equal(t, "sssh, it's a secret", string(header)) - }) -} - -func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestRequestHandler { - requests := []testserver.TestRequestHandler{ - { - Path: "/api/v4/internal/hello", - Handler: func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - - fmt.Fprint(w, "Hello") - }, - }, - { - Path: "/api/v4/internal/post_endpoint", - Handler: func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - - b, err := io.ReadAll(r.Body) - defer r.Body.Close() - - require.NoError(t, err) - - fmt.Fprint(w, "Echo: "+string(b)) - }, - }, - { - Path: "/api/v4/internal/auth", - Handler: func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, r.Header.Get(secretHeaderName)) - }, - }, - { - Path: "/api/v4/internal/error", - Handler: func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - body := map[string]string{ - "message": "Don't do that", - } - json.NewEncoder(w).Encode(body) - }, - }, - { - Path: "/api/v4/internal/broken", - Handler: func(w http.ResponseWriter, r *http.Request) { - panic("Broken") - }, - }, - } - - relativeURLRoot = strings.Trim(relativeURLRoot, "/") - if relativeURLRoot != "" { - for i, r := range requests { - requests[i].Path = fmt.Sprintf("/%s%s", relativeURLRoot, r.Path) - } - } - - return requests -} diff --git a/client/gitlabnet.go b/client/gitlabnet.go deleted file mode 100644 index f71c110..0000000 --- a/client/gitlabnet.go +++ /dev/null @@ -1,161 +0,0 @@ -package client - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "gitlab.com/gitlab-org/labkit/log" -) - -const ( - internalApiPath = "/api/v4/internal" - secretHeaderName = "Gitlab-Shared-Secret" - defaultUserAgent = "GitLab-Shell" -) - -type ErrorResponse struct { - Message string `json:"message"` -} - -type GitlabNetClient struct { - httpClient *HttpClient - user string - password string - secret string - userAgent string -} - -func NewGitlabNetClient( - user, - password, - secret string, - httpClient *HttpClient, -) (*GitlabNetClient, error) { - - if httpClient == nil { - return nil, fmt.Errorf("Unsupported protocol") - } - - return &GitlabNetClient{ - httpClient: httpClient, - user: user, - password: password, - secret: secret, - userAgent: defaultUserAgent, - }, nil -} - -// SetUserAgent overrides the default user agent for the User-Agent header field -// for subsequent requests for the GitlabNetClient -func (c *GitlabNetClient) SetUserAgent(ua string) { - c.userAgent = ua -} - -func normalizePath(path string) string { - if !strings.HasPrefix(path, "/") { - path = "/" + path - } - - if !strings.HasPrefix(path, internalApiPath) { - path = internalApiPath + path - } - return path -} - -func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, error) { - var jsonReader io.Reader - if data != nil { - jsonData, err := json.Marshal(data) - if err != nil { - return nil, err - } - - jsonReader = bytes.NewReader(jsonData) - } - - request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader) - if err != nil { - return nil, err - } - - return request, nil -} - -func parseError(resp *http.Response) error { - if resp.StatusCode >= 200 && resp.StatusCode <= 399 { - return nil - } - defer resp.Body.Close() - parsedResponse := &ErrorResponse{} - - if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { - return fmt.Errorf("Internal API error (%v)", resp.StatusCode) - } else { - return fmt.Errorf(parsedResponse.Message) - } - -} - -func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) { - return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil) -} - -func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) { - return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data) -} - -func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) { - request, err := newRequest(ctx, method, c.httpClient.Host, path, data) - if err != nil { - return nil, err - } - - user, password := c.user, c.password - if user != "" && password != "" { - request.SetBasicAuth(user, password) - } - - encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.secret)) - request.Header.Set(secretHeaderName, encodedSecret) - - request.Header.Add("Content-Type", "application/json") - request.Header.Add("User-Agent", c.userAgent) - request.Close = true - - start := time.Now() - response, err := c.httpClient.Do(request) - fields := log.Fields{ - "method": method, - "url": request.URL.String(), - "duration_ms": time.Since(start) / time.Millisecond, - } - logger := log.WithContextFields(ctx, fields) - - if err != nil { - logger.WithError(err).Error("Internal API unreachable") - return nil, fmt.Errorf("Internal API unreachable") - } - - if response != nil { - logger = logger.WithField("status", response.StatusCode) - } - if err := parseError(response); err != nil { - logger.WithError(err).Error("Internal API error") - return nil, err - } - - if response.ContentLength >= 0 { - logger = logger.WithField("content_length_bytes", response.ContentLength) - } - - logger.Info("Finished HTTP request") - - return response, nil -} diff --git a/client/httpclient.go b/client/httpclient.go deleted file mode 100644 index 72238f8..0000000 --- a/client/httpclient.go +++ /dev/null @@ -1,189 +0,0 @@ -package client - -import ( - "context" - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "net" - "net/http" - "os" - "path/filepath" - "strings" - "time" - - "gitlab.com/gitlab-org/labkit/correlation" - "gitlab.com/gitlab-org/labkit/log" - "gitlab.com/gitlab-org/labkit/tracing" -) - -const ( - socketBaseUrl = "http://unix" - unixSocketProtocol = "http+unix://" - httpProtocol = "http://" - httpsProtocol = "https://" - defaultReadTimeoutSeconds = 300 -) - -var ( - ErrCafileNotFound = errors.New("cafile not found") -) - -type HttpClient struct { - *http.Client - Host string -} - -type httpClientCfg struct { - keyPath, certPath string - caFile, caPath string -} - -func (hcc httpClientCfg) HaveCertAndKey() bool { return hcc.keyPath != "" && hcc.certPath != "" } - -// HTTPClientOpt provides options for configuring an HttpClient -type HTTPClientOpt func(*httpClientCfg) - -// WithClientCert will configure the HttpClient to provide client certificates -// when connecting to a server. -func WithClientCert(certPath, keyPath string) HTTPClientOpt { - return func(hcc *httpClientCfg) { - hcc.keyPath = keyPath - hcc.certPath = certPath - } -} - -// Deprecated: use NewHTTPClientWithOpts - https://gitlab.com/gitlab-org/gitlab-shell/-/issues/484 -func NewHTTPClient(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64) *HttpClient { - c, err := NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath, selfSignedCert, readTimeoutSeconds, nil) - if err != nil { - log.WithError(err).Error("new http client with opts") - } - return c -} - -// NewHTTPClientWithOpts builds an HTTP client using the provided options -func NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64, opts []HTTPClientOpt) (*HttpClient, error) { - var transport *http.Transport - var host string - var err error - if strings.HasPrefix(gitlabURL, unixSocketProtocol) { - transport, host = buildSocketTransport(gitlabURL, gitlabRelativeURLRoot) - } else if strings.HasPrefix(gitlabURL, httpProtocol) { - transport, host = buildHttpTransport(gitlabURL) - } else if strings.HasPrefix(gitlabURL, httpsProtocol) { - if _, err := os.Stat(caFile); err != nil { - if os.IsNotExist(err) { - return nil, fmt.Errorf("cannot find cafile '%s': %w", caFile, ErrCafileNotFound) - } - return nil, err - } - - hcc := &httpClientCfg{ - caFile: caFile, - caPath: caPath, - } - - for _, opt := range opts { - opt(hcc) - } - - transport, host, err = buildHttpsTransport(*hcc, selfSignedCert, gitlabURL) - if err != nil { - return nil, err - } - } else { - return nil, errors.New("unknown GitLab URL prefix") - } - - c := &http.Client{ - Transport: correlation.NewInstrumentedRoundTripper(tracing.NewRoundTripper(transport)), - Timeout: readTimeout(readTimeoutSeconds), - } - - client := &HttpClient{Client: c, Host: host} - - return client, nil -} - -func buildSocketTransport(gitlabURL, gitlabRelativeURLRoot string) (*http.Transport, string) { - socketPath := strings.TrimPrefix(gitlabURL, unixSocketProtocol) - - transport := &http.Transport{ - DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { - dialer := net.Dialer{} - return dialer.DialContext(ctx, "unix", socketPath) - }, - } - - host := socketBaseUrl - gitlabRelativeURLRoot = strings.Trim(gitlabRelativeURLRoot, "/") - if gitlabRelativeURLRoot != "" { - host = host + "/" + gitlabRelativeURLRoot - } - - return transport, host -} - -func buildHttpsTransport(hcc httpClientCfg, selfSignedCert bool, gitlabURL string) (*http.Transport, string, error) { - certPool, err := x509.SystemCertPool() - - if err != nil { - certPool = x509.NewCertPool() - } - - if hcc.caFile != "" { - addCertToPool(certPool, hcc.caFile) - } - - if hcc.caPath != "" { - fis, _ := os.ReadDir(hcc.caPath) - for _, fi := range fis { - if fi.IsDir() { - continue - } - - addCertToPool(certPool, filepath.Join(hcc.caPath, fi.Name())) - } - } - tlsConfig := &tls.Config{ - RootCAs: certPool, - InsecureSkipVerify: selfSignedCert, - MinVersion: tls.VersionTLS12, - } - - if hcc.HaveCertAndKey() { - cert, err := tls.LoadX509KeyPair(hcc.certPath, hcc.keyPath) - if err != nil { - return nil, "", err - } - tlsConfig.Certificates = []tls.Certificate{cert} - tlsConfig.BuildNameToCertificate() - } - - transport := &http.Transport{ - TLSClientConfig: tlsConfig, - } - - return transport, gitlabURL, err -} - -func addCertToPool(certPool *x509.CertPool, fileName string) { - cert, err := os.ReadFile(fileName) - if err == nil { - certPool.AppendCertsFromPEM(cert) - } -} - -func buildHttpTransport(gitlabURL string) (*http.Transport, string) { - return &http.Transport{}, gitlabURL -} - -func readTimeout(timeoutSeconds uint64) time.Duration { - if timeoutSeconds == 0 { - timeoutSeconds = defaultReadTimeoutSeconds - } - - return time.Duration(timeoutSeconds) * time.Second -} diff --git a/client/httpclient_test.go b/client/httpclient_test.go deleted file mode 100644 index f7a6340..0000000 --- a/client/httpclient_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package client - -import ( - "context" - "encoding/base64" - "fmt" - "io" - "net/http" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitlab-shell/client/testserver" -) - -func TestReadTimeout(t *testing.T) { - expectedSeconds := uint64(300) - - client, err := NewHTTPClientWithOpts("http://localhost:3000", "", "", "", false, expectedSeconds, nil) - require.NoError(t, err) - - require.NotNil(t, client) - require.Equal(t, time.Duration(expectedSeconds)*time.Second, client.Client.Timeout) -} - -const ( - username = "basic_auth_user" - password = "basic_auth_password" -) - -func TestBasicAuthSettings(t *testing.T) { - requests := []testserver.TestRequestHandler{ - { - Path: "/api/v4/internal/get_endpoint", - Handler: func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - - fmt.Fprint(w, r.Header.Get("Authorization")) - }, - }, - { - Path: "/api/v4/internal/post_endpoint", - Handler: func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - - fmt.Fprint(w, r.Header.Get("Authorization")) - }, - }, - } - - client := setup(t, username, password, requests) - - response, err := client.Get(context.Background(), "/get_endpoint") - require.NoError(t, err) - testBasicAuthHeaders(t, response) - - response, err = client.Post(context.Background(), "/post_endpoint", nil) - require.NoError(t, err) - testBasicAuthHeaders(t, response) -} - -func testBasicAuthHeaders(t *testing.T, response *http.Response) { - defer response.Body.Close() - - require.NotNil(t, response) - responseBody, err := io.ReadAll(response.Body) - require.NoError(t, err) - - headerParts := strings.Split(string(responseBody), " ") - require.Equal(t, "Basic", headerParts[0]) - - credentials, err := base64.StdEncoding.DecodeString(headerParts[1]) - require.NoError(t, err) - - require.Equal(t, username+":"+password, string(credentials)) -} - -func TestEmptyBasicAuthSettings(t *testing.T) { - requests := []testserver.TestRequestHandler{ - { - Path: "/api/v4/internal/empty_basic_auth", - Handler: func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "", r.Header.Get("Authorization")) - }, - }, - } - - client := setup(t, "", "", requests) - - _, err := client.Get(context.Background(), "/empty_basic_auth") - require.NoError(t, err) -} - -func TestRequestWithUserAgent(t *testing.T) { - const gitalyUserAgent = "gitaly/13.5.0" - requests := []testserver.TestRequestHandler{ - { - Path: "/api/v4/internal/default_user_agent", - Handler: func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, defaultUserAgent, r.UserAgent()) - }, - }, - { - Path: "/api/v4/internal/override_user_agent", - Handler: func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, gitalyUserAgent, r.UserAgent()) - }, - }, - } - - client := setup(t, "", "", requests) - - _, err := client.Get(context.Background(), "/default_user_agent") - require.NoError(t, err) - - client.SetUserAgent(gitalyUserAgent) - _, err = client.Get(context.Background(), "/override_user_agent") - require.NoError(t, err) - -} - -func setup(t *testing.T, username, password string, requests []testserver.TestRequestHandler) *GitlabNetClient { - url := testserver.StartHttpServer(t, requests) - - httpClient, err := NewHTTPClientWithOpts(url, "", "", "", false, 1, nil) - require.NoError(t, err) - - client, err := NewGitlabNetClient(username, password, "", httpClient) - require.NoError(t, err) - - return client -} diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go deleted file mode 100644 index d2c2293..0000000 --- a/client/httpsclient_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package client - -import ( - "context" - "fmt" - "io" - "net/http" - "path" - "testing" - - "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitlab-shell/client/testserver" - "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" -) - -//go:generate openssl req -newkey rsa:4096 -new -nodes -x509 -days 3650 -out ../internal/testhelper/testdata/testroot/certs/client/server.crt -keyout ../internal/testhelper/testdata/testroot/certs/client/key.pem -subj "/C=US/ST=California/L=San Francisco/O=GitLab/OU=GitLab-Shell/CN=localhost" -func TestSuccessfulRequests(t *testing.T) { - testCases := []struct { - desc string - caFile, caPath string - selfSigned bool - clientCAPath, clientCertPath, clientKeyPath string // used for TLS client certs - }{ - { - desc: "Valid CaFile", - caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"), - }, - { - desc: "Valid CaPath", - caPath: path.Join(testhelper.TestRoot, "certs/valid"), - caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"), - }, - { - desc: "Invalid cert with self signed cert option enabled", - caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"), - selfSigned: true, - }, - { - desc: "Client certs with CA", - caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"), - // Run the command "go generate httpsclient_test.go" to - // regenerate the following test fixtures: - clientCAPath: path.Join(testhelper.TestRoot, "certs/client/server.crt"), - clientCertPath: path.Join(testhelper.TestRoot, "certs/client/server.crt"), - clientKeyPath: path.Join(testhelper.TestRoot, "certs/client/key.pem"), - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - client, err := setupWithRequests(t, tc.caFile, tc.caPath, tc.clientCAPath, tc.clientCertPath, tc.clientKeyPath, tc.selfSigned) - require.NoError(t, err) - - response, err := client.Get(context.Background(), "/hello") - require.NoError(t, err) - require.NotNil(t, response) - - defer response.Body.Close() - - responseBody, err := io.ReadAll(response.Body) - require.NoError(t, err) - require.Equal(t, string(responseBody), "Hello") - }) - } -} - -func TestFailedRequests(t *testing.T) { - testCases := []struct { - desc string - caFile string - caPath string - expectedError string - }{ - { - desc: "Invalid CaFile", - caFile: path.Join(testhelper.TestRoot, "certs/invalid/server.crt"), - expectedError: "Internal API unreachable", - }, - { - desc: "Invalid CaPath", - caPath: path.Join(testhelper.TestRoot, "certs/invalid"), - }, - { - desc: "Empty config", - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - client, err := setupWithRequests(t, tc.caFile, tc.caPath, "", "", "", false) - if tc.caFile == "" { - require.Error(t, err) - require.ErrorIs(t, err, ErrCafileNotFound) - } else { - _, err = client.Get(context.Background(), "/hello") - require.Error(t, err) - - require.Equal(t, err.Error(), tc.expectedError) - } - }) - } -} - -func setupWithRequests(t *testing.T, caFile, caPath, clientCAPath, clientCertPath, clientKeyPath string, selfSigned bool) (*GitlabNetClient, error) { - testhelper.PrepareTestRootDir(t) - - requests := []testserver.TestRequestHandler{ - { - Path: "/api/v4/internal/hello", - Handler: func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - - fmt.Fprint(w, "Hello") - }, - }, - } - - url := testserver.StartHttpsServer(t, requests, clientCAPath) - - var opts []HTTPClientOpt - if clientCertPath != "" && clientKeyPath != "" { - opts = append(opts, WithClientCert(clientCertPath, clientKeyPath)) - } - - httpClient, err := NewHTTPClientWithOpts(url, "", caFile, caPath, selfSigned, 1, opts) - if err != nil { - return nil, err - } - - client, err := NewGitlabNetClient("", "", "", httpClient) - - return client, err -} diff --git a/client/testserver/gitalyserver.go b/client/testserver/gitalyserver.go index 7159d16..92fbe86 100644 --- a/client/testserver/gitalyserver.go +++ b/client/testserver/gitalyserver.go @@ -13,7 +13,10 @@ import ( "google.golang.org/grpc/metadata" ) -type TestGitalyServer struct{ ReceivedMD metadata.MD } +type TestGitalyServer struct { + pb.UnimplementedSSHServiceServer + ReceivedMD metadata.MD +} func (s *TestGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error { req, err := stream.Recv() diff --git a/client/testserver/testserver.go b/client/testserver/testserver.go index c263aa0..9a1509c 100644 --- a/client/testserver/testserver.go +++ b/client/testserver/testserver.go @@ -1,8 +1,6 @@ package testserver import ( - "crypto/tls" - "crypto/x509" "io" "log" "net" @@ -14,7 +12,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" ) var ( @@ -59,40 +56,6 @@ func StartHttpServer(t *testing.T, handlers []TestRequestHandler) string { return server.URL } -func StartHttpsServer(t *testing.T, handlers []TestRequestHandler, clientCAPath string) string { - t.Helper() - - crt := path.Join(testhelper.TestRoot, "certs/valid/server.crt") - key := path.Join(testhelper.TestRoot, "certs/valid/server.key") - - server := httptest.NewUnstartedServer(buildHandler(handlers)) - cer, err := tls.LoadX509KeyPair(crt, key) - require.NoError(t, err) - - server.TLS = &tls.Config{ - Certificates: []tls.Certificate{cer}, - MinVersion: tls.VersionTLS12, - } - server.TLS.BuildNameToCertificate() - - if clientCAPath != "" { - caCert, err := os.ReadFile(clientCAPath) - require.NoError(t, err) - - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - - server.TLS.ClientCAs = caCertPool - server.TLS.ClientAuth = tls.RequireAndVerifyClientCert - } - - server.StartTLS() - - t.Cleanup(func() { server.Close() }) - - return server.URL -} - func buildHandler(handlers []TestRequestHandler) http.Handler { h := http.NewServeMux() |