diff options
-rw-r--r-- | go/cmd/gitlab-shell/main.go | 6 | ||||
-rw-r--r-- | go/internal/command/command.go | 12 | ||||
-rw-r--r-- | go/internal/command/command_test.go | 4 | ||||
-rw-r--r-- | go/internal/command/discover/discover.go | 11 | ||||
-rw-r--r-- | go/internal/command/discover/discover_test.go | 16 | ||||
-rw-r--r-- | go/internal/command/fallback/fallback.go | 4 | ||||
-rw-r--r-- | go/internal/command/fallback/fallback_test.go | 6 | ||||
-rw-r--r-- | go/internal/command/twofactorrecover/twofactorrecover.go | 25 | ||||
-rw-r--r-- | go/internal/command/twofactorrecover/twofactorrecover_test.go | 8 |
9 files changed, 52 insertions, 40 deletions
diff --git a/go/cmd/gitlab-shell/main.go b/go/cmd/gitlab-shell/main.go index 6e39d8b..a8aef75 100644 --- a/go/cmd/gitlab-shell/main.go +++ b/go/cmd/gitlab-shell/main.go @@ -29,7 +29,7 @@ func findRootDir() (string, error) { func execRuby(rootDir string, readWriter *readwriter.ReadWriter) { cmd := &fallback.Command{RootDir: rootDir, Args: os.Args} - if err := cmd.Execute(readWriter); err != nil { + if err := cmd.Execute(); err != nil { fmt.Fprintf(readWriter.ErrOut, "Failed to exec: %v\n", err) os.Exit(1) } @@ -56,7 +56,7 @@ func main() { execRuby(rootDir, readWriter) } - cmd, err := command.New(os.Args, config) + cmd, err := command.New(os.Args, config, readWriter) if err != nil { // For now this could happen if `SSH_CONNECTION` is not set on // the environment @@ -66,7 +66,7 @@ func main() { // The command will write to STDOUT on execution or replace the current // process in case of the `fallback.Command` - if err = cmd.Execute(readWriter); err != nil { + if err = cmd.Execute(); err != nil { fmt.Fprintf(readWriter.ErrOut, "%v\n", err) os.Exit(1) } diff --git a/go/internal/command/command.go b/go/internal/command/command.go index 560e0b2..0ceb7fc 100644 --- a/go/internal/command/command.go +++ b/go/internal/command/command.go @@ -10,10 +10,10 @@ import ( ) type Command interface { - Execute(*readwriter.ReadWriter) error + Execute() error } -func New(arguments []string, config *config.Config) (Command, error) { +func New(arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) { args, err := commandargs.Parse(arguments) if err != nil { @@ -21,18 +21,18 @@ func New(arguments []string, config *config.Config) (Command, error) { } if config.FeatureEnabled(string(args.CommandType)) { - return buildCommand(args, config), nil + return buildCommand(args, config, readWriter), nil } return &fallback.Command{RootDir: config.RootDir, Args: arguments}, nil } -func buildCommand(args *commandargs.CommandArgs, config *config.Config) Command { +func buildCommand(args *commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command { switch args.CommandType { case commandargs.Discover: - return &discover.Command{Config: config, Args: args} + return &discover.Command{Config: config, Args: args, ReadWriter: readWriter} case commandargs.TwoFactorRecover: - return &twofactorrecover.Command{Config: config, Args: args} + return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter} } return nil diff --git a/go/internal/command/command_test.go b/go/internal/command/command_test.go index 42c5112..228dc7a 100644 --- a/go/internal/command/command_test.go +++ b/go/internal/command/command_test.go @@ -65,7 +65,7 @@ func TestNew(t *testing.T) { restoreEnv := testhelper.TempEnv(tc.environment) defer restoreEnv() - command, err := New(tc.arguments, tc.config) + command, err := New(tc.arguments, tc.config, nil) assert.NoError(t, err) assert.IsType(t, tc.expectedType, command) @@ -78,7 +78,7 @@ func TestFailingNew(t *testing.T) { restoreEnv := testhelper.TempEnv(map[string]string{}) defer restoreEnv() - _, err := New([]string{}, &config.Config{}) + _, err := New([]string{}, &config.Config{}, nil) assert.Error(t, err, "Only ssh allowed") }) diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go index 9bb442f..7d4ad2b 100644 --- a/go/internal/command/discover/discover.go +++ b/go/internal/command/discover/discover.go @@ -10,20 +10,21 @@ import ( ) type Command struct { - Config *config.Config - Args *commandargs.CommandArgs + Config *config.Config + Args *commandargs.CommandArgs + ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(readWriter *readwriter.ReadWriter) error { +func (c *Command) Execute() error { response, err := c.getUserInfo() if err != nil { return fmt.Errorf("Failed to get username: %v", err) } if response.IsAnonymous() { - fmt.Fprintf(readWriter.Out, "Welcome to GitLab, Anonymous!\n") + fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, Anonymous!\n") } else { - fmt.Fprintf(readWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) + fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) } return nil diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go index f85add8..40c0d4e 100644 --- a/go/internal/command/discover/discover_test.go +++ b/go/internal/command/discover/discover_test.go @@ -78,10 +78,14 @@ func TestExecute(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments} buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } - err := cmd.Execute(&readwriter.ReadWriter{Out: buffer}) + err := cmd.Execute() assert.NoError(t, err) assert.Equal(t, tc.expectedOutput, buffer.String()) @@ -118,10 +122,14 @@ func TestFailingExecute(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments} buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } - err := cmd.Execute(&readwriter.ReadWriter{Out: buffer}) + err := cmd.Execute() assert.Empty(t, buffer.String()) assert.EqualError(t, err, tc.expectedError) diff --git a/go/internal/command/fallback/fallback.go b/go/internal/command/fallback/fallback.go index 71e2a98..f525a57 100644 --- a/go/internal/command/fallback/fallback.go +++ b/go/internal/command/fallback/fallback.go @@ -4,8 +4,6 @@ import ( "os" "path/filepath" "syscall" - - "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" ) type Command struct { @@ -22,7 +20,7 @@ const ( RubyProgram = "gitlab-shell-ruby" ) -func (c *Command) Execute(*readwriter.ReadWriter) error { +func (c *Command) Execute() error { rubyCmd := filepath.Join(c.RootDir, "bin", RubyProgram) // Ensure rubyArgs[0] is the full path to gitlab-shell-ruby diff --git a/go/internal/command/fallback/fallback_test.go b/go/internal/command/fallback/fallback_test.go index 2d67b14..afd752b 100644 --- a/go/internal/command/fallback/fallback_test.go +++ b/go/internal/command/fallback/fallback_test.go @@ -49,7 +49,7 @@ func TestExecuteExecsCommandSuccesfully(t *testing.T) { fake.Setup() defer fake.Cleanup() - require.NoError(t, cmd.Execute(nil)) + require.NoError(t, cmd.Execute()) require.True(t, fake.Called) require.Equal(t, fake.Filename, "/tmp/bin/gitlab-shell-ruby") require.Equal(t, fake.Args, []string{"/tmp/bin/gitlab-shell-ruby", "foo", "bar"}) @@ -64,12 +64,12 @@ func TestExecuteExecsCommandOnError(t *testing.T) { fake.Setup() defer fake.Cleanup() - require.Error(t, cmd.Execute(nil)) + require.Error(t, cmd.Execute()) require.True(t, fake.Called) } func TestExecuteGivenNonexistentCommand(t *testing.T) { cmd := &Command{RootDir: "/tmp/does/not/exist", Args: fakeArgs} - require.Error(t, cmd.Execute(nil)) + require.Error(t, cmd.Execute()) } diff --git a/go/internal/command/twofactorrecover/twofactorrecover.go b/go/internal/command/twofactorrecover/twofactorrecover.go index e77a334..faa35db 100644 --- a/go/internal/command/twofactorrecover/twofactorrecover.go +++ b/go/internal/command/twofactorrecover/twofactorrecover.go @@ -11,33 +11,34 @@ import ( ) type Command struct { - Config *config.Config - Args *commandargs.CommandArgs + Config *config.Config + Args *commandargs.CommandArgs + ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(readWriter *readwriter.ReadWriter) error { - if c.canContinue(readWriter) { - c.displayRecoveryCodes(readWriter) +func (c *Command) Execute() error { + if c.canContinue() { + c.displayRecoveryCodes() } else { - fmt.Fprintln(readWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") + fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") } return nil } -func (c *Command) canContinue(readWriter *readwriter.ReadWriter) bool { +func (c *Command) canContinue() bool { question := "Are you sure you want to generate new two-factor recovery codes?\n" + "Any existing recovery codes you saved will be invalidated. (yes/no)" - fmt.Fprintln(readWriter.Out, question) + fmt.Fprintln(c.ReadWriter.Out, question) var answer string - fmt.Fscanln(readWriter.In, &answer) + fmt.Fscanln(c.ReadWriter.In, &answer) return answer == "yes" } -func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) { +func (c *Command) displayRecoveryCodes() { codes, err := c.getRecoveryCodes() if err == nil { @@ -47,9 +48,9 @@ func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) { "\n\nDuring sign in, use one of the codes above when prompted for\n" + "your two-factor code. Then, visit your Profile Settings and add\n" + "a new device so you do not lose access to your account again.\n" - fmt.Fprint(readWriter.Out, messageWithCodes) + fmt.Fprint(c.ReadWriter.Out, messageWithCodes) } else { - fmt.Fprintf(readWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err) + fmt.Fprintf(c.ReadWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err) } } diff --git a/go/internal/command/twofactorrecover/twofactorrecover_test.go b/go/internal/command/twofactorrecover/twofactorrecover_test.go index be76520..bcca12a 100644 --- a/go/internal/command/twofactorrecover/twofactorrecover_test.go +++ b/go/internal/command/twofactorrecover/twofactorrecover_test.go @@ -122,9 +122,13 @@ func TestExecute(t *testing.T) { output := &bytes.Buffer{} input := bytes.NewBufferString(tc.answer) - cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, + } - err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input}) + err := cmd.Execute() assert.NoError(t, err) assert.Equal(t, tc.expectedOutput, output.String()) |