diff options
Diffstat (limited to 'internal/command')
29 files changed, 200 insertions, 77 deletions
diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go index 7554761..736aeed 100644 --- a/internal/command/authorizedkeys/authorized_keys.go +++ b/internal/command/authorizedkeys/authorized_keys.go @@ -1,6 +1,7 @@ package authorizedkeys import ( + "context" "fmt" "strconv" @@ -17,7 +18,7 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { // Do and return nothing when the expected and actual user don't match. // This can happen when the user in sshd_config doesn't match the user // trying to login. When nothing is printed, the user will be denied access. @@ -27,15 +28,15 @@ func (c *Command) Execute() error { return nil } - if err := c.printKeyLine(); err != nil { + if err := c.printKeyLine(ctx); err != nil { return err } return nil } -func (c *Command) printKeyLine() error { - response, err := c.getAuthorizedKey() +func (c *Command) printKeyLine(ctx context.Context) error { + response, err := c.getAuthorizedKey(ctx) if err != nil { fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key)) return nil @@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error { return nil } -func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) { +func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) { client, err := authorizedkeys.NewClient(c.Config) if err != nil { return nil, err } - return client.GetByKey(c.Args.Key) + return client.GetByKey(ctx, c.Args.Key) } diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go index e12f4fa..f15c34d 100644 --- a/internal/command/authorizedkeys/authorized_keys_test.go +++ b/internal/command/authorizedkeys/authorized_keys_test.go @@ -2,6 +2,7 @@ package authorizedkeys import ( "bytes" + "context" "encoding/json" "net/http" "testing" @@ -97,7 +98,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go index ab5f2f8..44f6c47 100644 --- a/internal/command/authorizedprincipals/authorized_principals.go +++ b/internal/command/authorizedprincipals/authorized_principals.go @@ -1,6 +1,7 @@ package authorizedprincipals import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -15,7 +16,7 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { if err := c.printPrincipalLines(); err != nil { return err } diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go index f11dd0f..ec97b65 100644 --- a/internal/command/authorizedprincipals/authorized_principals_test.go +++ b/internal/command/authorizedprincipals/authorized_principals_test.go @@ -2,6 +2,7 @@ package authorizedprincipals import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -54,7 +55,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/command.go b/internal/command/command.go index 283b4a1..c69219b 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -1,6 +1,8 @@ package command import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -16,10 +18,13 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/tracing" ) type Command interface { - Execute() error + Execute(ctx context.Context) error } func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) { @@ -35,6 +40,28 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re return nil, disallowedcommand.Error } +// ContextWithCorrelationID() will always return a background Context +// with a correlation ID. It will first attempt to extract the ID from +// an environment variable. If is not available, a random one will be +// generated. +func ContextWithCorrelationID() (context.Context, func()) { + ctx, finished := tracing.ExtractFromEnv(context.Background()) + defer finished() + + correlationID := correlation.ExtractFromContext(ctx) + if correlationID == "" { + correlationID, err := correlation.RandomID() + if err != nil { + log.WithError(err).Warn("unable to generate correlation ID") + } else { + log.Info("generated random correlation ID") + ctx = correlation.ContextWithCorrelation(ctx, correlationID) + } + } + + return ctx, finished +} + func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command { switch e.Name { case executable.GitlabShell: diff --git a/internal/command/command_test.go b/internal/command/command_test.go index db55e7d..9160abf 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -2,6 +2,7 @@ package command import ( "errors" + "os" "testing" "github.com/stretchr/testify/require" @@ -20,6 +21,7 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" + "gitlab.com/gitlab-org/labkit/correlation" ) var ( @@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) { }) } } + +func TestContextWithCorrelationID(t *testing.T) { + testCases := []struct { + name string + additionalEnv map[string]string + expectedCorrelationID string + }{ + { + name: "no CORRELATION_ID in environment", + }, + { + name: "CORRELATION_ID in environment", + additionalEnv: map[string]string{ + "CORRELATION_ID": "abc123", + }, + expectedCorrelationID: "abc123", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resetEnvironment := addAdditionalEnv(tc.additionalEnv) + defer resetEnvironment() + + ctx, finished := ContextWithCorrelationID() + require.NotNil(t, ctx, "ctx is nil") + require.NotNil(t, finished, "finished is nil") + correlationID := correlation.ExtractFromContext(ctx) + require.NotEmpty(t, correlationID) + + if tc.expectedCorrelationID != "" { + require.Equal(t, tc.expectedCorrelationID, correlationID) + } + defer finished() + }) + } +} + +// addAdditionalEnv will configure additional environment values +// and return a deferrable function to reset the environment to +// it's original state after the test +func addAdditionalEnv(envMap map[string]string) func() { + prevValues := map[string]string{} + unsetValues := []string{} + for k, v := range envMap { + value, exists := os.LookupEnv(k) + if exists { + prevValues[k] = value + } else { + unsetValues = append(unsetValues, k) + } + os.Setenv(k, v) + } + + return func() { + for k, v := range prevValues { + os.Setenv(k, v) + } + + for _, k := range unsetValues { + os.Unsetenv(k) + } + + } +} diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go index 3aa7456..822be32 100644 --- a/internal/command/discover/discover.go +++ b/internal/command/discover/discover.go @@ -1,6 +1,7 @@ package discover import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -15,8 +16,8 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { - response, err := c.getUserInfo() +func (c *Command) Execute(ctx context.Context) error { + response, err := c.getUserInfo(ctx) if err != nil { return fmt.Errorf("Failed to get username: %v", err) } @@ -30,11 +31,11 @@ func (c *Command) Execute() error { return nil } -func (c *Command) getUserInfo() (*discover.Response, error) { +func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) { client, err := discover.NewClient(c.Config) if err != nil { return nil, err } - return client.GetByCommandArgs(c.Args) + return client.GetByCommandArgs(ctx, c.Args) } diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go index 8edbcb9..5431410 100644 --- a/internal/command/discover/discover_test.go +++ b/internal/command/discover/discover_test.go @@ -2,6 +2,7 @@ package discover import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -83,7 +84,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) @@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, tc.expectedError) diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go index bbc73dc..b04eb0d 100644 --- a/internal/command/healthcheck/healthcheck.go +++ b/internal/command/healthcheck/healthcheck.go @@ -1,6 +1,7 @@ package healthcheck import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" @@ -18,8 +19,8 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { - response, err := c.runCheck() +func (c *Command) Execute(ctx context.Context) error { + response, err := c.runCheck(ctx) if err != nil { return fmt.Errorf("%v: FAILED - %v", apiMessage, err) } @@ -34,13 +35,13 @@ func (c *Command) Execute() error { return nil } -func (c *Command) runCheck() (*healthcheck.Response, error) { +func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) { client, err := healthcheck.NewClient(c.Config) if err != nil { return nil, err } - response, err := client.Check() + response, err := client.Check(ctx) if err != nil { return nil, err } diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go index 7479bcb..d05e563 100644 --- a/internal/command/healthcheck/healthcheck_test.go +++ b/internal/command/healthcheck/healthcheck_test.go @@ -2,6 +2,7 @@ package healthcheck import ( "bytes" + "context" "encoding/json" "net/http" "testing" @@ -53,7 +54,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String()) @@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Error(t, err, "Redis available via internal API: FAILED") require.Equal(t, "Internal API available: OK\n", buffer.String()) } @@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)") } diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go index 2aaac2a..dab69ab 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate.go +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -1,6 +1,7 @@ package lfsauthenticate import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -34,7 +35,7 @@ type Payload struct { ExpiresIn int `json:"expires_in,omitempty"` } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) < 3 { return disallowedcommand.Error @@ -49,12 +50,12 @@ func (c *Command) Execute() error { return err } - accessResponse, err := c.verifyAccess(action, repo) + accessResponse, err := c.verifyAccess(ctx, action, repo) if err != nil { return err } - payload, err := c.authenticate(operation, repo, accessResponse.UserId) + payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId) if err != nil { // return nothing just like Ruby's GitlabShell#lfs_authenticate does return nil @@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) { return action, nil } -func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(action, repo) + return cmd.Verify(ctx, action, repo) } -func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) { +func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) { client, err := lfsauthenticate.NewClient(c.Config, c.Args) if err != nil { return nil, err } - response, err := client.Authenticate(operation, repo, userId) + response, err := client.Authenticate(ctx, operation, repo, userId) if err != nil { return nil, err } diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go index a1c7aec..55998ab 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate_test.go +++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go @@ -2,6 +2,7 @@ package lfsauthenticate import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -54,7 +55,7 @@ func TestFailedRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Error(t, err) require.Equal(t, tc.expectedOutput, err.Error()) @@ -146,7 +147,7 @@ func TestLfsAuthenticateRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go index b283890..6f3d03e 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken.go +++ b/internal/command/personalaccesstoken/personalaccesstoken.go @@ -1,6 +1,7 @@ package personalaccesstoken import ( + "context" "errors" "fmt" "strconv" @@ -31,13 +32,13 @@ type tokenArgs struct { ExpiresDate string // Calculated, a TTL is passed from command-line. } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { err := c.parseTokenArgs() if err != nil { return err } - response, err := c.getPersonalAccessToken() + response, err := c.getPersonalAccessToken(ctx) if err != nil { return err } @@ -76,11 +77,11 @@ func (c *Command) parseTokenArgs() error { return nil } -func (c *Command) getPersonalAccessToken() (*personalaccesstoken.Response, error) { +func (c *Command) getPersonalAccessToken(ctx context.Context) (*personalaccesstoken.Response, error) { client, err := personalaccesstoken.NewClient(c.Config) if err != nil { return nil, err } - return client.GetPersonalAccessToken(c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate) + return client.GetPersonalAccessToken(ctx, c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate) } diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go index bc748ab..5970142 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken_test.go +++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go @@ -2,6 +2,7 @@ package personalaccesstoken import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -170,7 +171,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) if tc.expectedError == "" { assert.NoError(t, err) diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go index 8bee484..2a0c146 100644 --- a/internal/command/receivepack/gitalycall_test.go +++ b/internal/command/receivepack/gitalycall_test.go @@ -2,6 +2,7 @@ package receivepack import ( "bytes" + "context" "testing" "github.com/sirupsen/logrus" @@ -42,7 +43,7 @@ func TestReceivePack(t *testing.T) { hook := testhelper.SetupLogger() - err = cmd.Execute() + err = cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String()) diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go index 7271264..4d5c686 100644 --- a/internal/command/receivepack/receivepack.go +++ b/internal/command/receivepack/receivepack.go @@ -1,6 +1,8 @@ package receivepack import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -15,14 +17,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -33,14 +35,14 @@ func (c *Command) Execute() error { ReadWriter: c.ReadWriter, EOFSent: true, } - return customAction.Execute(response) + return customAction.Execute(ctx, response) } return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go index a4632b4..44cb680 100644 --- a/internal/command/receivepack/receivepack_test.go +++ b/internal/command/receivepack/receivepack_test.go @@ -2,6 +2,7 @@ package receivepack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -18,7 +19,7 @@ func TestForbiddenAccess(t *testing.T) { cmd, _, cleanup := setup(t, "disallowed", requests) defer cleanup() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } @@ -26,7 +27,7 @@ func TestCustomReceivePack(t *testing.T) { cmd, output, cleanup := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t)) defer cleanup() - require.NoError(t, cmd.Execute()) + require.NoError(t, cmd.Execute(context.Background())) require.Equal(t, "customoutput", output.String()) } diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go index 5d2d709..9fcdde4 100644 --- a/internal/command/shared/accessverifier/accessverifier.go +++ b/internal/command/shared/accessverifier/accessverifier.go @@ -1,6 +1,7 @@ package accessverifier import ( + "context" "errors" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -18,13 +19,13 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) { +func (c *Command) Verify(ctx context.Context, action commandargs.CommandType, repo string) (*Response, error) { client, err := accessverifier.NewClient(c.Config) if err != nil { return nil, err } - response, err := client.Verify(c.Args, action, repo) + response, err := client.Verify(ctx, c.Args, action, repo) if err != nil { return nil, err } diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go index 998e622..8ad87b8 100644 --- a/internal/command/shared/accessverifier/accessverifier_test.go +++ b/internal/command/shared/accessverifier/accessverifier_test.go @@ -2,6 +2,7 @@ package accessverifier import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -65,7 +66,7 @@ func TestMissingUser(t *testing.T) { defer cleanup() cmd.Args = &commandargs.Shell{GitlabKeyId: "2"} - _, err := cmd.Verify(action, repo) + _, err := cmd.Verify(context.Background(), action, repo) require.Equal(t, "missing user", err.Error()) } @@ -75,7 +76,7 @@ func TestConsoleMessages(t *testing.T) { defer cleanup() cmd.Args = &commandargs.Shell{GitlabKeyId: "1"} - cmd.Verify(action, repo) + cmd.Verify(context.Background(), action, repo) require.Equal(t, "remote: \nremote: console\nremote: message\nremote: \n", errBuf.String()) require.Empty(t, outBuf.String()) diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go index 2ba1091..0675d36 100644 --- a/internal/command/shared/customaction/customaction.go +++ b/internal/command/shared/customaction/customaction.go @@ -2,6 +2,7 @@ package customaction import ( "bytes" + "context" "errors" "gitlab.com/gitlab-org/gitlab-shell/client" @@ -34,7 +35,7 @@ type Command struct { EOFSent bool } -func (c *Command) Execute(response *accessverifier.Response) error { +func (c *Command) Execute(ctx context.Context, response *accessverifier.Response) error { data := response.Payload.Data apiEndpoints := data.ApiEndpoints @@ -42,10 +43,10 @@ func (c *Command) Execute(response *accessverifier.Response) error { return errors.New("Custom action error: Empty API endpoints") } - return c.processApiEndpoints(response) + return c.processApiEndpoints(ctx, response) } -func (c *Command) processApiEndpoints(response *accessverifier.Response) error { +func (c *Command) processApiEndpoints(ctx context.Context, response *accessverifier.Response) error { client, err := gitlabnet.GetClient(c.Config) if err != nil { @@ -64,7 +65,7 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error { log.WithFields(fields).Info("Performing custom action") - response, err := c.performRequest(client, endpoint, request) + response, err := c.performRequest(ctx, client, endpoint, request) if err != nil { return err } @@ -95,8 +96,8 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error { return nil } -func (c *Command) performRequest(client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) { - response, err := client.DoRequest(http.MethodPost, endpoint, request) +func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) { + response, err := client.DoRequest(ctx, http.MethodPost, endpoint, request) if err != nil { return nil, err } diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go index 46c5f32..119da5b 100644 --- a/internal/command/shared/customaction/customaction_test.go +++ b/internal/command/shared/customaction/customaction_test.go @@ -2,6 +2,7 @@ package customaction import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -78,7 +79,7 @@ func TestExecuteEOFSent(t *testing.T) { EOFSent: true, } - require.NoError(t, cmd.Execute(response)) + require.NoError(t, cmd.Execute(context.Background(), response)) // expect printing of info message, "custom" string from the first request // and "output" string from the second request @@ -148,7 +149,7 @@ func TestExecuteNoEOFSent(t *testing.T) { EOFSent: false, } - require.NoError(t, cmd.Execute(response)) + require.NoError(t, cmd.Execute(context.Background(), response)) // expect printing of info message, "custom" string from the first request // and "output" string from the second request diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go index 2f13cc5..f0a9e7b 100644 --- a/internal/command/twofactorrecover/twofactorrecover.go +++ b/internal/command/twofactorrecover/twofactorrecover.go @@ -1,6 +1,7 @@ package twofactorrecover import ( + "context" "fmt" "strings" @@ -16,9 +17,9 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { if c.canContinue() { - c.displayRecoveryCodes() + c.displayRecoveryCodes(ctx) } else { fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") } @@ -38,8 +39,8 @@ func (c *Command) canContinue() bool { return answer == "yes" } -func (c *Command) displayRecoveryCodes() { - codes, err := c.getRecoveryCodes() +func (c *Command) displayRecoveryCodes(ctx context.Context) { + codes, err := c.getRecoveryCodes(ctx) if err == nil { messageWithCodes := @@ -54,12 +55,12 @@ func (c *Command) displayRecoveryCodes() { } } -func (c *Command) getRecoveryCodes() ([]string, error) { +func (c *Command) getRecoveryCodes(ctx context.Context) ([]string, error) { client, err := twofactorrecover.NewClient(c.Config) if err != nil { return nil, err } - return client.GetRecoveryCodes(c.Args) + return client.GetRecoveryCodes(ctx, c.Args) } diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go index d2f931b..ea6abd6 100644 --- a/internal/command/twofactorrecover/twofactorrecover_test.go +++ b/internal/command/twofactorrecover/twofactorrecover_test.go @@ -2,6 +2,7 @@ package twofactorrecover import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -127,7 +128,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) assert.NoError(t, err) assert.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go index eaeb2b7..f74093a 100644 --- a/internal/command/uploadarchive/gitalycall_test.go +++ b/internal/command/uploadarchive/gitalycall_test.go @@ -2,6 +2,7 @@ package uploadarchive import ( "bytes" + "context" "testing" "github.com/sirupsen/logrus" @@ -38,7 +39,7 @@ func TestUploadPack(t *testing.T) { hook := testhelper.SetupLogger() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "UploadArchive: "+repo, output.String()) diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go index 9d4fbe0..178b42b 100644 --- a/internal/command/uploadarchive/uploadarchive.go +++ b/internal/command/uploadarchive/uploadarchive.go @@ -1,6 +1,8 @@ package uploadarchive import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -14,14 +16,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -29,8 +31,8 @@ func (c *Command) Execute() error { return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go index 7b03009..5426569 100644 --- a/internal/command/uploadarchive/uploadarchive_test.go +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -2,6 +2,7 @@ package uploadarchive import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go index d6762a2..22189b8 100644 --- a/internal/command/uploadpack/gitalycall_test.go +++ b/internal/command/uploadpack/gitalycall_test.go @@ -2,6 +2,7 @@ package uploadpack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/assert" @@ -37,7 +38,7 @@ func TestUploadPack(t *testing.T) { hook := testhelper.SetupLogger() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "UploadPack: "+repo, output.String()) diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go index 56814d7..fca3823 100644 --- a/internal/command/uploadpack/uploadpack.go +++ b/internal/command/uploadpack/uploadpack.go @@ -1,6 +1,8 @@ package uploadpack import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -15,14 +17,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -33,14 +35,14 @@ func (c *Command) Execute() error { ReadWriter: c.ReadWriter, EOFSent: false, } - return customAction.Execute(response) + return customAction.Execute(ctx, response) } return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go index 7ea8e5d..20edb57 100644 --- a/internal/command/uploadpack/uploadpack_test.go +++ b/internal/command/uploadpack/uploadpack_test.go @@ -2,6 +2,7 @@ package uploadpack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } |