summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
authorStan Hu <stanhu@gmail.com>2020-09-19 03:34:49 -0700
committerStan Hu <stanhu@gmail.com>2020-09-20 21:40:40 -0700
commita487572a904cc149840488eefdfe121173d8bcb5 (patch)
treed17cf7bff45492a587027851bb6e0bcb493cff58 /client
parentf100e7e83943b3bb5db232f5bf79a616fdba88f1 (diff)
downloadgitlab-shell-a487572a904cc149840488eefdfe121173d8bcb5.tar.gz
Make it possible to propagate correlation ID across processes
Previously, gitlab-shell did not pass a context through the application. Correlation IDs were generated down the call stack instead of passed around from the start execution. This has several potential downsides: 1. It's easier for programming mistakes to be made in future that lead to multiple correlation IDs being generated for a single request. 2. Correlation IDs cannot be passed in from upstream requests 3. Other advantages of context passing, such as distributed tracing is not possible. This commit changes the behavior: 1. Extract the correlation ID from the environment at the start of the application. 2. If no correlation ID exists, generate a random one. 3. Pass the correlation ID to the GitLabNet API requests. This change also enables other clients of GitLabNet (e.g. Gitaly) to pass along the correlation ID in the internal API requests (https://gitlab.com/gitlab-org/gitaly/-/issues/2725). Fixes https://gitlab.com/gitlab-org/gitlab-shell/-/issues/474
Diffstat (limited to 'client')
-rw-r--r--client/client_test.go21
-rw-r--r--client/gitlabnet.go28
-rw-r--r--client/httpclient_test.go7
-rw-r--r--client/httpsclient_test.go5
4 files changed, 29 insertions, 32 deletions
diff --git a/client/client_test.go b/client/client_test.go
index e92093a..e0650b2 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -78,7 +79,7 @@ func TestClients(t *testing.T) {
func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
t.Run("Successful get", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/hello")
+ response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
@@ -104,7 +105,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
hook := testhelper.SetupLogger()
data := map[string]string{"key": "value"}
- response, err := client.Post("/post_endpoint", data)
+ response, err := client.Post(context.Background(), "/post_endpoint", data)
require.NoError(t, err)
require.NotNil(t, response)
@@ -128,7 +129,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/missing")
+ response, err := client.Get(context.Background(), "/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
@@ -144,7 +145,7 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Post("/missing", map[string]string{})
+ response, err := client.Post(context.Background(), "/missing", map[string]string{})
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
@@ -161,13 +162,13 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
func testErrorMessage(t *testing.T, client *GitlabNetClient) {
t.Run("Error with message for GET", func(t *testing.T) {
- response, err := client.Get("/error")
+ response, err := client.Get(context.Background(), "/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
t.Run("Error with message for POST", func(t *testing.T) {
- response, err := client.Post("/error", map[string]string{})
+ response, err := client.Post(context.Background(), "/error", map[string]string{})
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
@@ -177,7 +178,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/broken")
+ response, err := client.Get(context.Background(), "/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
@@ -194,7 +195,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Post("/broken", map[string]string{})
+ response, err := client.Post(context.Background(), "/broken", map[string]string{})
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
@@ -211,7 +212,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
t.Run("Authentication headers for GET", func(t *testing.T) {
- response, err := client.Get("/auth")
+ response, err := client.Get(context.Background(), "/auth")
require.NoError(t, err)
require.NotNil(t, response)
@@ -226,7 +227,7 @@ func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
})
t.Run("Authentication headers for POST", func(t *testing.T) {
- response, err := client.Post("/auth", map[string]string{})
+ response, err := client.Post(context.Background(), "/auth", map[string]string{})
require.NoError(t, err)
require.NotNil(t, response)
diff --git a/client/gitlabnet.go b/client/gitlabnet.go
index 0657ca0..b908d04 100644
--- a/client/gitlabnet.go
+++ b/client/gitlabnet.go
@@ -11,8 +11,9 @@ import (
"strings"
"time"
- log "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/correlation"
+
+ log "github.com/sirupsen/logrus"
)
const (
@@ -59,7 +60,7 @@ func normalizePath(path string) string {
return path
}
-func newRequest(method, host, path string, data interface{}) (*http.Request, string, error) {
+func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, string, error) {
var jsonReader io.Reader
if data != nil {
jsonData, err := json.Marshal(data)
@@ -70,20 +71,13 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, str
jsonReader = bytes.NewReader(jsonData)
}
- correlationID, err := correlation.RandomID()
- ctx := context.Background()
-
- if err != nil {
- log.WithError(err).Warn("unable to generate correlation ID")
- } else {
- ctx = correlation.ContextWithCorrelation(ctx, correlationID)
- }
-
request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader)
if err != nil {
return nil, "", err
}
+ correlationID := correlation.ExtractFromContext(ctx)
+
return request, correlationID, nil
}
@@ -102,16 +96,16 @@ func parseError(resp *http.Response) error {
}
-func (c *GitlabNetClient) Get(path string) (*http.Response, error) {
- return c.DoRequest(http.MethodGet, normalizePath(path), nil)
+func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) {
+ return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil)
}
-func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) {
- return c.DoRequest(http.MethodPost, normalizePath(path), data)
+func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) {
+ return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data)
}
-func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) {
- request, correlationID, err := newRequest(method, c.httpClient.Host, path, data)
+func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) {
+ request, correlationID, err := newRequest(ctx, method, c.httpClient.Host, path, data)
if err != nil {
return nil, err
}
diff --git a/client/httpclient_test.go b/client/httpclient_test.go
index fce0cd5..97e1384 100644
--- a/client/httpclient_test.go
+++ b/client/httpclient_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"encoding/base64"
"fmt"
"io/ioutil"
@@ -51,11 +52,11 @@ func TestBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, username, password, requests)
defer cleanup()
- response, err := client.Get("/get_endpoint")
+ response, err := client.Get(context.Background(), "/get_endpoint")
require.NoError(t, err)
testBasicAuthHeaders(t, response)
- response, err = client.Post("/post_endpoint", nil)
+ response, err = client.Post(context.Background(), "/post_endpoint", nil)
require.NoError(t, err)
testBasicAuthHeaders(t, response)
}
@@ -89,7 +90,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, "", "", requests)
defer cleanup()
- _, err := client.Get("/empty_basic_auth")
+ _, err := client.Get(context.Background(), "/empty_basic_auth")
require.NoError(t, err)
}
diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go
index 1c7435f..0cf77e3 100644
--- a/client/httpsclient_test.go
+++ b/client/httpsclient_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"fmt"
"io/ioutil"
"net/http"
@@ -43,7 +44,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned)
defer cleanup()
- response, err := client.Get("/hello")
+ response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
@@ -80,7 +81,7 @@ func TestFailedRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false)
defer cleanup()
- _, err := client.Get("/hello")
+ _, err := client.Get(context.Background(), "/hello")
require.Error(t, err)
assert.Equal(t, err.Error(), "Internal API unreachable")