summaryrefslogtreecommitdiff
path: root/internal/command
diff options
context:
space:
mode:
Diffstat (limited to 'internal/command')
-rw-r--r--internal/command/authorizedkeys/authorized_keys.go13
-rw-r--r--internal/command/authorizedkeys/authorized_keys_test.go3
-rw-r--r--internal/command/authorizedprincipals/authorized_principals.go3
-rw-r--r--internal/command/authorizedprincipals/authorized_principals_test.go3
-rw-r--r--internal/command/command.go29
-rw-r--r--internal/command/command_test.go66
-rw-r--r--internal/command/discover/discover.go9
-rw-r--r--internal/command/discover/discover_test.go5
-rw-r--r--internal/command/healthcheck/healthcheck.go9
-rw-r--r--internal/command/healthcheck/healthcheck_test.go7
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate.go15
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate_test.go5
-rw-r--r--internal/command/personalaccesstoken/personalaccesstoken.go9
-rw-r--r--internal/command/personalaccesstoken/personalaccesstoken_test.go3
-rw-r--r--internal/command/receivepack/gitalycall_test.go3
-rw-r--r--internal/command/receivepack/receivepack.go12
-rw-r--r--internal/command/receivepack/receivepack_test.go5
-rw-r--r--internal/command/shared/accessverifier/accessverifier.go5
-rw-r--r--internal/command/shared/accessverifier/accessverifier_test.go5
-rw-r--r--internal/command/shared/customaction/customaction.go13
-rw-r--r--internal/command/shared/customaction/customaction_test.go5
-rw-r--r--internal/command/twofactorrecover/twofactorrecover.go13
-rw-r--r--internal/command/twofactorrecover/twofactorrecover_test.go3
-rw-r--r--internal/command/uploadarchive/gitalycall_test.go3
-rw-r--r--internal/command/uploadarchive/uploadarchive.go10
-rw-r--r--internal/command/uploadarchive/uploadarchive_test.go3
-rw-r--r--internal/command/uploadpack/gitalycall_test.go3
-rw-r--r--internal/command/uploadpack/uploadpack.go12
-rw-r--r--internal/command/uploadpack/uploadpack_test.go3
29 files changed, 200 insertions, 77 deletions
diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go
index 7554761..736aeed 100644
--- a/internal/command/authorizedkeys/authorized_keys.go
+++ b/internal/command/authorizedkeys/authorized_keys.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"fmt"
"strconv"
@@ -17,7 +18,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
// Do and return nothing when the expected and actual user don't match.
// This can happen when the user in sshd_config doesn't match the user
// trying to login. When nothing is printed, the user will be denied access.
@@ -27,15 +28,15 @@ func (c *Command) Execute() error {
return nil
}
- if err := c.printKeyLine(); err != nil {
+ if err := c.printKeyLine(ctx); err != nil {
return err
}
return nil
}
-func (c *Command) printKeyLine() error {
- response, err := c.getAuthorizedKey()
+func (c *Command) printKeyLine(ctx context.Context) error {
+ response, err := c.getAuthorizedKey(ctx)
if err != nil {
fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key))
return nil
@@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error {
return nil
}
-func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) {
+func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) {
client, err := authorizedkeys.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetByKey(c.Args.Key)
+ return client.GetByKey(ctx, c.Args.Key)
}
diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go
index e12f4fa..f15c34d 100644
--- a/internal/command/authorizedkeys/authorized_keys_test.go
+++ b/internal/command/authorizedkeys/authorized_keys_test.go
@@ -2,6 +2,7 @@ package authorizedkeys
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -97,7 +98,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go
index ab5f2f8..44f6c47 100644
--- a/internal/command/authorizedprincipals/authorized_principals.go
+++ b/internal/command/authorizedprincipals/authorized_principals.go
@@ -1,6 +1,7 @@
package authorizedprincipals
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -15,7 +16,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
if err := c.printPrincipalLines(); err != nil {
return err
}
diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go
index f11dd0f..ec97b65 100644
--- a/internal/command/authorizedprincipals/authorized_principals_test.go
+++ b/internal/command/authorizedprincipals/authorized_principals_test.go
@@ -2,6 +2,7 @@ package authorizedprincipals
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -54,7 +55,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
diff --git a/internal/command/command.go b/internal/command/command.go
index 283b4a1..c69219b 100644
--- a/internal/command/command.go
+++ b/internal/command/command.go
@@ -1,6 +1,8 @@
package command
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -16,10 +18,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/tracing"
)
type Command interface {
- Execute() error
+ Execute(ctx context.Context) error
}
func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
@@ -35,6 +40,28 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re
return nil, disallowedcommand.Error
}
+// ContextWithCorrelationID() will always return a background Context
+// with a correlation ID. It will first attempt to extract the ID from
+// an environment variable. If is not available, a random one will be
+// generated.
+func ContextWithCorrelationID() (context.Context, func()) {
+ ctx, finished := tracing.ExtractFromEnv(context.Background())
+ defer finished()
+
+ correlationID := correlation.ExtractFromContext(ctx)
+ if correlationID == "" {
+ correlationID, err := correlation.RandomID()
+ if err != nil {
+ log.WithError(err).Warn("unable to generate correlation ID")
+ } else {
+ log.Info("generated random correlation ID")
+ ctx = correlation.ContextWithCorrelation(ctx, correlationID)
+ }
+ }
+
+ return ctx, finished
+}
+
func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
switch e.Name {
case executable.GitlabShell:
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
index db55e7d..9160abf 100644
--- a/internal/command/command_test.go
+++ b/internal/command/command_test.go
@@ -2,6 +2,7 @@ package command
import (
"errors"
+ "os"
"testing"
"github.com/stretchr/testify/require"
@@ -20,6 +21,7 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+ "gitlab.com/gitlab-org/labkit/correlation"
)
var (
@@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) {
})
}
}
+
+func TestContextWithCorrelationID(t *testing.T) {
+ testCases := []struct {
+ name string
+ additionalEnv map[string]string
+ expectedCorrelationID string
+ }{
+ {
+ name: "no CORRELATION_ID in environment",
+ },
+ {
+ name: "CORRELATION_ID in environment",
+ additionalEnv: map[string]string{
+ "CORRELATION_ID": "abc123",
+ },
+ expectedCorrelationID: "abc123",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ resetEnvironment := addAdditionalEnv(tc.additionalEnv)
+ defer resetEnvironment()
+
+ ctx, finished := ContextWithCorrelationID()
+ require.NotNil(t, ctx, "ctx is nil")
+ require.NotNil(t, finished, "finished is nil")
+ correlationID := correlation.ExtractFromContext(ctx)
+ require.NotEmpty(t, correlationID)
+
+ if tc.expectedCorrelationID != "" {
+ require.Equal(t, tc.expectedCorrelationID, correlationID)
+ }
+ defer finished()
+ })
+ }
+}
+
+// addAdditionalEnv will configure additional environment values
+// and return a deferrable function to reset the environment to
+// it's original state after the test
+func addAdditionalEnv(envMap map[string]string) func() {
+ prevValues := map[string]string{}
+ unsetValues := []string{}
+ for k, v := range envMap {
+ value, exists := os.LookupEnv(k)
+ if exists {
+ prevValues[k] = value
+ } else {
+ unsetValues = append(unsetValues, k)
+ }
+ os.Setenv(k, v)
+ }
+
+ return func() {
+ for k, v := range prevValues {
+ os.Setenv(k, v)
+ }
+
+ for _, k := range unsetValues {
+ os.Unsetenv(k)
+ }
+
+ }
+}
diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go
index 3aa7456..822be32 100644
--- a/internal/command/discover/discover.go
+++ b/internal/command/discover/discover.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -15,8 +16,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
- response, err := c.getUserInfo()
+func (c *Command) Execute(ctx context.Context) error {
+ response, err := c.getUserInfo(ctx)
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
@@ -30,11 +31,11 @@ func (c *Command) Execute() error {
return nil
}
-func (c *Command) getUserInfo() (*discover.Response, error) {
+func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) {
client, err := discover.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetByCommandArgs(c.Args)
+ return client.GetByCommandArgs(ctx, c.Args)
}
diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go
index 8edbcb9..5431410 100644
--- a/internal/command/discover/discover_test.go
+++ b/internal/command/discover/discover_test.go
@@ -2,6 +2,7 @@ package discover
import (
"bytes"
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -83,7 +84,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
@@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, tc.expectedError)
diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go
index bbc73dc..b04eb0d 100644
--- a/internal/command/healthcheck/healthcheck.go
+++ b/internal/command/healthcheck/healthcheck.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
@@ -18,8 +19,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
- response, err := c.runCheck()
+func (c *Command) Execute(ctx context.Context) error {
+ response, err := c.runCheck(ctx)
if err != nil {
return fmt.Errorf("%v: FAILED - %v", apiMessage, err)
}
@@ -34,13 +35,13 @@ func (c *Command) Execute() error {
return nil
}
-func (c *Command) runCheck() (*healthcheck.Response, error) {
+func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) {
client, err := healthcheck.NewClient(c.Config)
if err != nil {
return nil, err
}
- response, err := client.Check()
+ response, err := client.Check(ctx)
if err != nil {
return nil, err
}
diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go
index 7479bcb..d05e563 100644
--- a/internal/command/healthcheck/healthcheck_test.go
+++ b/internal/command/healthcheck/healthcheck_test.go
@@ -2,6 +2,7 @@ package healthcheck
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -53,7 +54,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String())
@@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Error(t, err, "Redis available via internal API: FAILED")
require.Equal(t, "Internal API available: OK\n", buffer.String())
}
@@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)")
}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go
index 2aaac2a..dab69ab 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -34,7 +35,7 @@ type Payload struct {
ExpiresIn int `json:"expires_in,omitempty"`
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) < 3 {
return disallowedcommand.Error
@@ -49,12 +50,12 @@ func (c *Command) Execute() error {
return err
}
- accessResponse, err := c.verifyAccess(action, repo)
+ accessResponse, err := c.verifyAccess(ctx, action, repo)
if err != nil {
return err
}
- payload, err := c.authenticate(operation, repo, accessResponse.UserId)
+ payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
if err != nil {
// return nothing just like Ruby's GitlabShell#lfs_authenticate does
return nil
@@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) {
return action, nil
}
-func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(action, repo)
+ return cmd.Verify(ctx, action, repo)
}
-func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) {
+func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) {
client, err := lfsauthenticate.NewClient(c.Config, c.Args)
if err != nil {
return nil, err
}
- response, err := client.Authenticate(operation, repo, userId)
+ response, err := client.Authenticate(ctx, operation, repo, userId)
if err != nil {
return nil, err
}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go
index a1c7aec..55998ab 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate_test.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go
@@ -2,6 +2,7 @@ package lfsauthenticate
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -54,7 +55,7 @@ func TestFailedRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Error(t, err)
require.Equal(t, tc.expectedOutput, err.Error())
@@ -146,7 +147,7 @@ func TestLfsAuthenticateRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, output.String())
diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go
index b283890..6f3d03e 100644
--- a/internal/command/personalaccesstoken/personalaccesstoken.go
+++ b/internal/command/personalaccesstoken/personalaccesstoken.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"errors"
"fmt"
"strconv"
@@ -31,13 +32,13 @@ type tokenArgs struct {
ExpiresDate string // Calculated, a TTL is passed from command-line.
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
err := c.parseTokenArgs()
if err != nil {
return err
}
- response, err := c.getPersonalAccessToken()
+ response, err := c.getPersonalAccessToken(ctx)
if err != nil {
return err
}
@@ -76,11 +77,11 @@ func (c *Command) parseTokenArgs() error {
return nil
}
-func (c *Command) getPersonalAccessToken() (*personalaccesstoken.Response, error) {
+func (c *Command) getPersonalAccessToken(ctx context.Context) (*personalaccesstoken.Response, error) {
client, err := personalaccesstoken.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetPersonalAccessToken(c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate)
+ return client.GetPersonalAccessToken(ctx, c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate)
}
diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go
index bc748ab..5970142 100644
--- a/internal/command/personalaccesstoken/personalaccesstoken_test.go
+++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go
@@ -2,6 +2,7 @@ package personalaccesstoken
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -170,7 +171,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
if tc.expectedError == "" {
assert.NoError(t, err)
diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go
index 8bee484..2a0c146 100644
--- a/internal/command/receivepack/gitalycall_test.go
+++ b/internal/command/receivepack/gitalycall_test.go
@@ -2,6 +2,7 @@ package receivepack
import (
"bytes"
+ "context"
"testing"
"github.com/sirupsen/logrus"
@@ -42,7 +43,7 @@ func TestReceivePack(t *testing.T) {
hook := testhelper.SetupLogger()
- err = cmd.Execute()
+ err = cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String())
diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go
index 7271264..4d5c686 100644
--- a/internal/command/receivepack/receivepack.go
+++ b/internal/command/receivepack/receivepack.go
@@ -1,6 +1,8 @@
package receivepack
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -15,14 +17,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -33,14 +35,14 @@ func (c *Command) Execute() error {
ReadWriter: c.ReadWriter,
EOFSent: true,
}
- return customAction.Execute(response)
+ return customAction.Execute(ctx, response)
}
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go
index a4632b4..44cb680 100644
--- a/internal/command/receivepack/receivepack_test.go
+++ b/internal/command/receivepack/receivepack_test.go
@@ -2,6 +2,7 @@ package receivepack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -18,7 +19,7 @@ func TestForbiddenAccess(t *testing.T) {
cmd, _, cleanup := setup(t, "disallowed", requests)
defer cleanup()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
@@ -26,7 +27,7 @@ func TestCustomReceivePack(t *testing.T) {
cmd, output, cleanup := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t))
defer cleanup()
- require.NoError(t, cmd.Execute())
+ require.NoError(t, cmd.Execute(context.Background()))
require.Equal(t, "customoutput", output.String())
}
diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go
index 5d2d709..9fcdde4 100644
--- a/internal/command/shared/accessverifier/accessverifier.go
+++ b/internal/command/shared/accessverifier/accessverifier.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"errors"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -18,13 +19,13 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) {
+func (c *Command) Verify(ctx context.Context, action commandargs.CommandType, repo string) (*Response, error) {
client, err := accessverifier.NewClient(c.Config)
if err != nil {
return nil, err
}
- response, err := client.Verify(c.Args, action, repo)
+ response, err := client.Verify(ctx, c.Args, action, repo)
if err != nil {
return nil, err
}
diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go
index 998e622..8ad87b8 100644
--- a/internal/command/shared/accessverifier/accessverifier_test.go
+++ b/internal/command/shared/accessverifier/accessverifier_test.go
@@ -2,6 +2,7 @@ package accessverifier
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -65,7 +66,7 @@ func TestMissingUser(t *testing.T) {
defer cleanup()
cmd.Args = &commandargs.Shell{GitlabKeyId: "2"}
- _, err := cmd.Verify(action, repo)
+ _, err := cmd.Verify(context.Background(), action, repo)
require.Equal(t, "missing user", err.Error())
}
@@ -75,7 +76,7 @@ func TestConsoleMessages(t *testing.T) {
defer cleanup()
cmd.Args = &commandargs.Shell{GitlabKeyId: "1"}
- cmd.Verify(action, repo)
+ cmd.Verify(context.Background(), action, repo)
require.Equal(t, "remote: \nremote: console\nremote: message\nremote: \n", errBuf.String())
require.Empty(t, outBuf.String())
diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go
index 2ba1091..0675d36 100644
--- a/internal/command/shared/customaction/customaction.go
+++ b/internal/command/shared/customaction/customaction.go
@@ -2,6 +2,7 @@ package customaction
import (
"bytes"
+ "context"
"errors"
"gitlab.com/gitlab-org/gitlab-shell/client"
@@ -34,7 +35,7 @@ type Command struct {
EOFSent bool
}
-func (c *Command) Execute(response *accessverifier.Response) error {
+func (c *Command) Execute(ctx context.Context, response *accessverifier.Response) error {
data := response.Payload.Data
apiEndpoints := data.ApiEndpoints
@@ -42,10 +43,10 @@ func (c *Command) Execute(response *accessverifier.Response) error {
return errors.New("Custom action error: Empty API endpoints")
}
- return c.processApiEndpoints(response)
+ return c.processApiEndpoints(ctx, response)
}
-func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
+func (c *Command) processApiEndpoints(ctx context.Context, response *accessverifier.Response) error {
client, err := gitlabnet.GetClient(c.Config)
if err != nil {
@@ -64,7 +65,7 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
log.WithFields(fields).Info("Performing custom action")
- response, err := c.performRequest(client, endpoint, request)
+ response, err := c.performRequest(ctx, client, endpoint, request)
if err != nil {
return err
}
@@ -95,8 +96,8 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
return nil
}
-func (c *Command) performRequest(client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) {
- response, err := client.DoRequest(http.MethodPost, endpoint, request)
+func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) {
+ response, err := client.DoRequest(ctx, http.MethodPost, endpoint, request)
if err != nil {
return nil, err
}
diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go
index 46c5f32..119da5b 100644
--- a/internal/command/shared/customaction/customaction_test.go
+++ b/internal/command/shared/customaction/customaction_test.go
@@ -2,6 +2,7 @@ package customaction
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -78,7 +79,7 @@ func TestExecuteEOFSent(t *testing.T) {
EOFSent: true,
}
- require.NoError(t, cmd.Execute(response))
+ require.NoError(t, cmd.Execute(context.Background(), response))
// expect printing of info message, "custom" string from the first request
// and "output" string from the second request
@@ -148,7 +149,7 @@ func TestExecuteNoEOFSent(t *testing.T) {
EOFSent: false,
}
- require.NoError(t, cmd.Execute(response))
+ require.NoError(t, cmd.Execute(context.Background(), response))
// expect printing of info message, "custom" string from the first request
// and "output" string from the second request
diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go
index 2f13cc5..f0a9e7b 100644
--- a/internal/command/twofactorrecover/twofactorrecover.go
+++ b/internal/command/twofactorrecover/twofactorrecover.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"fmt"
"strings"
@@ -16,9 +17,9 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
if c.canContinue() {
- c.displayRecoveryCodes()
+ c.displayRecoveryCodes(ctx)
} else {
fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
}
@@ -38,8 +39,8 @@ func (c *Command) canContinue() bool {
return answer == "yes"
}
-func (c *Command) displayRecoveryCodes() {
- codes, err := c.getRecoveryCodes()
+func (c *Command) displayRecoveryCodes(ctx context.Context) {
+ codes, err := c.getRecoveryCodes(ctx)
if err == nil {
messageWithCodes :=
@@ -54,12 +55,12 @@ func (c *Command) displayRecoveryCodes() {
}
}
-func (c *Command) getRecoveryCodes() ([]string, error) {
+func (c *Command) getRecoveryCodes(ctx context.Context) ([]string, error) {
client, err := twofactorrecover.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetRecoveryCodes(c.Args)
+ return client.GetRecoveryCodes(ctx, c.Args)
}
diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go
index d2f931b..ea6abd6 100644
--- a/internal/command/twofactorrecover/twofactorrecover_test.go
+++ b/internal/command/twofactorrecover/twofactorrecover_test.go
@@ -2,6 +2,7 @@ package twofactorrecover
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -127,7 +128,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, output.String())
diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go
index eaeb2b7..f74093a 100644
--- a/internal/command/uploadarchive/gitalycall_test.go
+++ b/internal/command/uploadarchive/gitalycall_test.go
@@ -2,6 +2,7 @@ package uploadarchive
import (
"bytes"
+ "context"
"testing"
"github.com/sirupsen/logrus"
@@ -38,7 +39,7 @@ func TestUploadPack(t *testing.T) {
hook := testhelper.SetupLogger()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "UploadArchive: "+repo, output.String())
diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go
index 9d4fbe0..178b42b 100644
--- a/internal/command/uploadarchive/uploadarchive.go
+++ b/internal/command/uploadarchive/uploadarchive.go
@@ -1,6 +1,8 @@
package uploadarchive
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -14,14 +16,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -29,8 +31,8 @@ func (c *Command) Execute() error {
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go
index 7b03009..5426569 100644
--- a/internal/command/uploadarchive/uploadarchive_test.go
+++ b/internal/command/uploadarchive/uploadarchive_test.go
@@ -2,6 +2,7 @@ package uploadarchive
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go
index d6762a2..22189b8 100644
--- a/internal/command/uploadpack/gitalycall_test.go
+++ b/internal/command/uploadpack/gitalycall_test.go
@@ -2,6 +2,7 @@ package uploadpack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -37,7 +38,7 @@ func TestUploadPack(t *testing.T) {
hook := testhelper.SetupLogger()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "UploadPack: "+repo, output.String())
diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go
index 56814d7..fca3823 100644
--- a/internal/command/uploadpack/uploadpack.go
+++ b/internal/command/uploadpack/uploadpack.go
@@ -1,6 +1,8 @@
package uploadpack
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -15,14 +17,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -33,14 +35,14 @@ func (c *Command) Execute() error {
ReadWriter: c.ReadWriter,
EOFSent: false,
}
- return customAction.Execute(response)
+ return customAction.Execute(ctx, response)
}
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go
index 7ea8e5d..20edb57 100644
--- a/internal/command/uploadpack/uploadpack_test.go
+++ b/internal/command/uploadpack/uploadpack_test.go
@@ -2,6 +2,7 @@ package uploadpack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}