diff options
author | Nick Thomas <nick@gitlab.com> | 2019-03-15 17:16:17 +0000 |
---|---|---|
committer | Nick Thomas <nick@gitlab.com> | 2019-03-15 17:16:17 +0000 |
commit | f237aba6df1c1873f1f9d5ba18c3b8924d85cb51 (patch) | |
tree | 22d69b9450693bb153e58dbe8b7cd6feb3f8e1e0 | |
parent | 049beb74303a03d9fa598d23b150e0ccea3cd60d (diff) | |
parent | 83c0f18e1de04b3bad9c424084e738e911c47336 (diff) | |
download | gitlab-shell-f237aba6df1c1873f1f9d5ba18c3b8924d85cb51.tar.gz |
Merge branch 'bvl-discover-command' into 'master'
Call gitlab "/internal/discover" from go
Closes #175
See merge request gitlab-org/gitlab-shell!283
-rw-r--r-- | go/cmd/gitlab-shell/main.go | 19 | ||||
-rw-r--r-- | go/internal/command/command.go | 3 | ||||
-rw-r--r-- | go/internal/command/discover/discover.go | 34 | ||||
-rw-r--r-- | go/internal/command/discover/discover_test.go | 131 | ||||
-rw-r--r-- | go/internal/command/fallback/fallback.go | 4 | ||||
-rw-r--r-- | go/internal/command/reporting/reporter.go | 8 | ||||
-rw-r--r-- | go/internal/gitlabnet/client.go | 77 | ||||
-rw-r--r-- | go/internal/gitlabnet/client_test.go | 131 | ||||
-rw-r--r-- | go/internal/gitlabnet/discover/client.go | 76 | ||||
-rw-r--r-- | go/internal/gitlabnet/discover/client_test.go | 131 | ||||
-rw-r--r-- | go/internal/gitlabnet/socketclient.go | 46 | ||||
-rw-r--r-- | go/internal/gitlabnet/testserver/testserver.go | 56 | ||||
-rw-r--r-- | spec/gitlab_shell_gitlab_shell_spec.rb | 33 |
13 files changed, 729 insertions, 20 deletions
diff --git a/go/cmd/gitlab-shell/main.go b/go/cmd/gitlab-shell/main.go index 07623b4..2ed319d 100644 --- a/go/cmd/gitlab-shell/main.go +++ b/go/cmd/gitlab-shell/main.go @@ -7,25 +7,28 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/go/internal/command" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" ) var ( - binDir string - rootDir string + binDir string + rootDir string + reporter *reporting.Reporter ) func init() { binDir = filepath.Dir(os.Args[0]) rootDir = filepath.Dir(binDir) + reporter = &reporting.Reporter{Out: os.Stdout, ErrOut: os.Stderr} } // rubyExec will never return. It either replaces the current process with a // Ruby interpreter, or outputs an error and kills the process. func execRuby() { cmd := &fallback.Command{} - if err := cmd.Execute(); err != nil { - fmt.Fprintf(os.Stderr, "Failed to exec: %v\n", err) + if err := cmd.Execute(reporter); err != nil { + fmt.Fprintf(reporter.ErrOut, "Failed to exec: %v\n", err) os.Exit(1) } } @@ -35,7 +38,7 @@ func main() { // warning as this isn't something we can sustain indefinitely config, err := config.NewFromDir(rootDir) if err != nil { - fmt.Fprintln(os.Stderr, "Failed to read config, falling back to gitlab-shell-ruby") + fmt.Fprintln(reporter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby") execRuby() } @@ -43,14 +46,14 @@ func main() { if err != nil { // For now this could happen if `SSH_CONNECTION` is not set on // the environment - fmt.Fprintf(os.Stderr, "%v\n", err) + fmt.Fprintf(reporter.ErrOut, "%v\n", err) os.Exit(1) } // The command will write to STDOUT on execution or replace the current // process in case of the `fallback.Command` - if err = cmd.Execute(); err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) + if err = cmd.Execute(reporter); err != nil { + fmt.Fprintf(reporter.ErrOut, "%v\n", err) os.Exit(1) } } diff --git a/go/internal/command/command.go b/go/internal/command/command.go index cb2acdc..d4649de 100644 --- a/go/internal/command/command.go +++ b/go/internal/command/command.go @@ -4,11 +4,12 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" ) type Command interface { - Execute() error + Execute(*reporting.Reporter) error } func New(arguments []string, config *config.Config) (Command, error) { diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go index 63a7a32..8ad2868 100644 --- a/go/internal/command/discover/discover.go +++ b/go/internal/command/discover/discover.go @@ -4,7 +4,9 @@ import ( "fmt" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" ) type Command struct { @@ -12,6 +14,34 @@ type Command struct { Args *commandargs.CommandArgs } -func (c *Command) Execute() error { - return fmt.Errorf("No feature is implemented yet") +func (c *Command) Execute(reporter *reporting.Reporter) error { + response, err := c.getUserInfo() + if err != nil { + return fmt.Errorf("Failed to get username: %v", err) + } + + if response.IsAnonymous() { + fmt.Fprintf(reporter.Out, "Welcome to GitLab, Anonymous!\n") + } else { + fmt.Fprintf(reporter.Out, "Welcome to GitLab, @%s!\n", response.Username) + } + + return nil +} + +func (c *Command) getUserInfo() (*discover.Response, error) { + client, err := discover.NewClient(c.Config) + if err != nil { + return nil, err + } + + if c.Args.GitlabKeyId != "" { + return client.GetByKeyId(c.Args.GitlabKeyId) + } else if c.Args.GitlabUsername != "" { + return client.GetByUsername(c.Args.GitlabUsername) + } else { + // There was no 'who' information, this matches the ruby error + // message. + return nil, fmt.Errorf("who='' is invalid") + } } diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go new file mode 100644 index 0000000..ec6f931 --- /dev/null +++ b/go/internal/command/discover/discover_test.go @@ -0,0 +1,131 @@ +package discover + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +var ( + testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket} + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("key_id") == "1" || r.URL.Query().Get("username") == "alex-doe" { + body := map[string]interface{}{ + "id": 2, + "username": "alex-doe", + "name": "Alex Doe", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_message" { + body := map[string]string{ + "message": "Forbidden!", + } + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken" { + w.WriteHeader(http.StatusInternalServerError) + } else { + fmt.Fprint(w, "null") + } + }, + }, + } +) + +func TestExecute(t *testing.T) { + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.CommandArgs + expectedOutput string + }{ + { + desc: "With a known username", + arguments: &commandargs.CommandArgs{GitlabUsername: "alex-doe"}, + expectedOutput: "Welcome to GitLab, @alex-doe!\n", + }, + { + desc: "With a known key id", + arguments: &commandargs.CommandArgs{GitlabKeyId: "1"}, + expectedOutput: "Welcome to GitLab, @alex-doe!\n", + }, + { + desc: "With an unknown key", + arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"}, + expectedOutput: "Welcome to GitLab, Anonymous!\n", + }, + { + desc: "With an unknown username", + arguments: &commandargs.CommandArgs{GitlabUsername: "unknown"}, + expectedOutput: "Welcome to GitLab, Anonymous!\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + cmd := &Command{Config: testConfig, Args: tc.arguments} + buffer := &bytes.Buffer{} + + err := cmd.Execute(&reporting.Reporter{Out: buffer}) + + assert.NoError(t, err) + assert.Equal(t, tc.expectedOutput, buffer.String()) + }) + } +} + +func TestFailingExecute(t *testing.T) { + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.CommandArgs + expectedError string + }{ + { + desc: "With missing arguments", + arguments: &commandargs.CommandArgs{}, + expectedError: "Failed to get username: who='' is invalid", + }, + { + desc: "When the API returns an error", + arguments: &commandargs.CommandArgs{GitlabUsername: "broken_message"}, + expectedError: "Failed to get username: Forbidden!", + }, + { + desc: "When the API fails", + arguments: &commandargs.CommandArgs{GitlabUsername: "broken"}, + expectedError: "Failed to get username: Internal API error (500)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + cmd := &Command{Config: testConfig, Args: tc.arguments} + buffer := &bytes.Buffer{} + + err := cmd.Execute(&reporting.Reporter{Out: buffer}) + + assert.Empty(t, buffer.String()) + assert.EqualError(t, err, tc.expectedError) + }) + } +} diff --git a/go/internal/command/fallback/fallback.go b/go/internal/command/fallback/fallback.go index a136657..a2c73ed 100644 --- a/go/internal/command/fallback/fallback.go +++ b/go/internal/command/fallback/fallback.go @@ -4,6 +4,8 @@ import ( "os" "path/filepath" "syscall" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" ) type Command struct{} @@ -12,7 +14,7 @@ var ( binDir = filepath.Dir(os.Args[0]) ) -func (c *Command) Execute() error { +func (c *Command) Execute(_ *reporting.Reporter) error { rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby") execErr := syscall.Exec(rubyCmd, os.Args, os.Environ()) return execErr diff --git a/go/internal/command/reporting/reporter.go b/go/internal/command/reporting/reporter.go new file mode 100644 index 0000000..74bca59 --- /dev/null +++ b/go/internal/command/reporting/reporter.go @@ -0,0 +1,8 @@ +package reporting + +import "io" + +type Reporter struct { + Out io.Writer + ErrOut io.Writer +} diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go new file mode 100644 index 0000000..abc218f --- /dev/null +++ b/go/internal/gitlabnet/client.go @@ -0,0 +1,77 @@ +package gitlabnet + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" +) + +const ( + internalApiPath = "/api/v4/internal" + secretHeaderName = "Gitlab-Shared-Secret" +) + +type GitlabClient interface { + Get(path string) (*http.Response, error) + // TODO: implement posts + // Post(path string) (http.Response, error) +} + +type ErrorResponse struct { + Message string `json:"message"` +} + +func GetClient(config *config.Config) (GitlabClient, error) { + url := config.GitlabUrl + if strings.HasPrefix(url, UnixSocketProtocol) { + return buildSocketClient(config), nil + } + + return nil, fmt.Errorf("Unsupported protocol") +} + +func normalizePath(path string) string { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + if !strings.HasPrefix(path, internalApiPath) { + path = internalApiPath + path + } + return path +} + +func parseError(resp *http.Response) error { + if resp.StatusCode >= 200 && resp.StatusCode <= 299 { + return nil + } + defer resp.Body.Close() + parsedResponse := &ErrorResponse{} + + if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { + return fmt.Errorf("Internal API error (%v)", resp.StatusCode) + } else { + return fmt.Errorf(parsedResponse.Message) + } + +} + +func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) { + encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret)) + request.Header.Set(secretHeaderName, encodedSecret) + + response, err := client.Do(request) + if err != nil { + return nil, fmt.Errorf("Internal API unreachable") + } + + if err := parseError(response); err != nil { + return nil, err + } + + return response, nil +} diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go new file mode 100644 index 0000000..f69f284 --- /dev/null +++ b/go/internal/gitlabnet/client_test.go @@ -0,0 +1,131 @@ +package gitlabnet + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +func TestClients(t *testing.T) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/hello", + Handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello") + }, + }, + { + Path: "/api/v4/internal/auth", + Handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, r.Header.Get(secretHeaderName)) + }, + }, + { + Path: "/api/v4/internal/error", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + body := map[string]string{ + "message": "Don't do that", + } + json.NewEncoder(w).Encode(body) + }, + }, + { + Path: "/api/v4/internal/broken", + Handler: func(w http.ResponseWriter, r *http.Request) { + panic("Broken") + }, + }, + } + testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"} + + testCases := []struct { + desc string + client GitlabClient + server func([]testserver.TestRequestHandler) (func(), error) + }{ + { + desc: "Socket client", + client: buildSocketClient(testConfig), + server: testserver.StartSocketHttpServer, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + cleanup, err := tc.server(requests) + defer cleanup() + require.NoError(t, err) + + testBrokenRequest(t, tc.client) + testSuccessfulGet(t, tc.client) + testMissing(t, tc.client) + testErrorMessage(t, tc.client) + testAuthenticationHeader(t, tc.client) + }) + } +} + +func testSuccessfulGet(t *testing.T, client GitlabClient) { + t.Run("Successful get", func(t *testing.T) { + response, err := client.Get("/hello") + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, string(responseBody), "Hello") + }) +} + +func testMissing(t *testing.T, client GitlabClient) { + t.Run("Missing error", func(t *testing.T) { + response, err := client.Get("/missing") + assert.EqualError(t, err, "Internal API error (404)") + assert.Nil(t, response) + }) +} + +func testErrorMessage(t *testing.T, client GitlabClient) { + t.Run("Error with message", func(t *testing.T) { + response, err := client.Get("/error") + assert.EqualError(t, err, "Don't do that") + assert.Nil(t, response) + }) +} + +func testBrokenRequest(t *testing.T, client GitlabClient) { + t.Run("Broken request", func(t *testing.T) { + response, err := client.Get("/broken") + assert.EqualError(t, err, "Internal API unreachable") + assert.Nil(t, response) + }) +} + +func testAuthenticationHeader(t *testing.T, client GitlabClient) { + t.Run("Authentication headers", func(t *testing.T) { + response, err := client.Get("/auth") + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + header, err := base64.StdEncoding.DecodeString(string(responseBody)) + require.NoError(t, err) + assert.Equal(t, "sssh, it's a secret", string(header)) + }) +} diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go new file mode 100644 index 0000000..8df78fb --- /dev/null +++ b/go/internal/gitlabnet/discover/client.go @@ -0,0 +1,76 @@ +package discover + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" +) + +type Client struct { + config *config.Config + client gitlabnet.GitlabClient +} + +type Response struct { + UserId int64 `json:"id"` + Name string `json:"name"` + Username string `json:"username"` +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client}, nil +} + +func (c *Client) GetByKeyId(keyId string) (*Response, error) { + params := url.Values{} + params.Add("key_id", keyId) + + return c.getResponse(params) +} + +func (c *Client) GetByUsername(username string) (*Response, error) { + params := url.Values{} + params.Add("username", username) + + return c.getResponse(params) +} + +func (c *Client) parseResponse(resp *http.Response) (*Response, error) { + parsedResponse := &Response{} + + if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { + return nil, err + } else { + return parsedResponse, nil + } +} + +func (c *Client) getResponse(params url.Values) (*Response, error) { + path := "/discover?" + params.Encode() + response, err := c.client.Get(path) + + if err != nil { + return nil, err + } + + defer response.Body.Close() + parsedResponse, err := c.parseResponse(response) + if err != nil { + return nil, fmt.Errorf("Parsing failed") + } + + return parsedResponse, nil +} + +func (r *Response) IsAnonymous() bool { + return r.UserId < 1 +} diff --git a/go/internal/gitlabnet/discover/client_test.go b/go/internal/gitlabnet/discover/client_test.go new file mode 100644 index 0000000..e88cedd --- /dev/null +++ b/go/internal/gitlabnet/discover/client_test.go @@ -0,0 +1,131 @@ +package discover + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + testConfig *config.Config + requests []testserver.TestRequestHandler +) + +func init() { + testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket} + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("key_id") == "1" { + body := &Response{ + UserId: 2, + Username: "alex-doe", + Name: "Alex Doe", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "jane-doe" { + body := &Response{ + UserId: 1, + Username: "jane-doe", + Name: "Jane Doe", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_message" { + w.WriteHeader(http.StatusForbidden) + body := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_json" { + w.Write([]byte("{ \"message\": \"broken json!\"")) + } else if r.URL.Query().Get("username") == "broken_empty" { + w.WriteHeader(http.StatusForbidden) + } else { + fmt.Fprint(w, "null") + } + }, + }, + } +} + +func TestGetByKeyId(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + result, err := client.GetByKeyId("1") + assert.NoError(t, err) + assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result) +} + +func TestGetByUsername(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + result, err := client.GetByUsername("jane-doe") + assert.NoError(t, err) + assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result) +} + +func TestMissingUser(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + result, err := client.GetByUsername("missing") + assert.NoError(t, err) + assert.True(t, result.IsAnonymous()) +} + +func TestErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + fakeUsername string + expectedError string + }{ + { + desc: "A response with an error message", + fakeUsername: "broken_message", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + fakeUsername: "broken_json", + expectedError: "Parsing failed", + }, + { + desc: "An error response without message", + fakeUsername: "broken_empty", + expectedError: "Internal API error (403)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + resp, err := client.GetByUsername(tc.fakeUsername) + + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + }) + } +} + +func setup(t *testing.T) (*Client, func()) { + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + + client, err := NewClient(testConfig) + require.NoError(t, err) + + return client, cleanup +} diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go new file mode 100644 index 0000000..3bd7c70 --- /dev/null +++ b/go/internal/gitlabnet/socketclient.go @@ -0,0 +1,46 @@ +package gitlabnet + +import ( + "context" + "net" + "net/http" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" +) + +const ( + // We need to set the base URL to something starting with HTTP, the host + // itself is ignored as we're talking over a socket. + socketBaseUrl = "http://unix" + UnixSocketProtocol = "http+unix://" +) + +type GitlabSocketClient struct { + httpClient *http.Client + config *config.Config +} + +func buildSocketClient(config *config.Config) *GitlabSocketClient { + path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol) + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", path) + }, + }, + } + + return &GitlabSocketClient{httpClient: httpClient, config: config} +} + +func (c *GitlabSocketClient) Get(path string) (*http.Response, error) { + path = normalizePath(path) + + request, err := http.NewRequest("GET", socketBaseUrl+path, nil) + if err != nil { + return nil, err + } + + return doRequest(c.httpClient, c.config, request) +} diff --git a/go/internal/gitlabnet/testserver/testserver.go b/go/internal/gitlabnet/testserver/testserver.go new file mode 100644 index 0000000..9640fd7 --- /dev/null +++ b/go/internal/gitlabnet/testserver/testserver.go @@ -0,0 +1,56 @@ +package testserver + +import ( + "io/ioutil" + "log" + "net" + "net/http" + "os" + "path" + "path/filepath" +) + +var ( + tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api") + TestSocket = path.Join(tempDir, "internal.sock") +) + +type TestRequestHandler struct { + Path string + Handler func(w http.ResponseWriter, r *http.Request) +} + +func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) { + if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil { + return nil, err + } + + socketListener, err := net.Listen("unix", TestSocket) + if err != nil { + return nil, err + } + + server := http.Server{ + Handler: buildHandler(handlers), + // We'll put this server through some nasty stuff we don't want + // in our test output + ErrorLog: log.New(ioutil.Discard, "", 0), + } + go server.Serve(socketListener) + + return cleanupSocket, nil +} + +func cleanupSocket() { + os.RemoveAll(tempDir) +} + +func buildHandler(handlers []TestRequestHandler) http.Handler { + h := http.NewServeMux() + + for _, handler := range handlers { + h.HandleFunc(handler.Path, handler.Handler) + } + + return h +} diff --git a/spec/gitlab_shell_gitlab_shell_spec.rb b/spec/gitlab_shell_gitlab_shell_spec.rb index 11692d3..cb3fd9c 100644 --- a/spec/gitlab_shell_gitlab_shell_spec.rb +++ b/spec/gitlab_shell_gitlab_shell_spec.rb @@ -30,12 +30,19 @@ describe 'bin/gitlab-shell' do @server = HTTPUNIXServer.new(BindAddress: tmp_socket_path) @server.mount_proc('/api/v4/internal/discover') do |req, res| - if req.query['key_id'] == '100' || - req.query['user_id'] == '10' || - req.query['username'] == 'someuser' + identifier = req.query['key_id'] || req.query['username'] || req.query['user_id'] + known_identifiers = %w(10 someuser 100) + if known_identifiers.include?(identifier) res.status = 200 res.content_type = 'application/json' res.body = '{"id":1, "name": "Some User", "username": "someuser"}' + elsif identifier == 'broken_message' + res.status = 401 + res.body = '{"message": "Forbidden!"}' + elsif identifier && identifier != 'broken' + res.status = 200 + res.content_type = 'application/json' + res.body = 'null' else res.status = 500 end @@ -145,11 +152,7 @@ describe 'bin/gitlab-shell' do ) end - it_behaves_like 'results with keys' do - before do - pending - end - end + it_behaves_like 'results with keys' it 'outputs "Only ssh allowed"' do _, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-someuser"], env: {}) @@ -157,6 +160,20 @@ describe 'bin/gitlab-shell' do expect(stderr).to eq("Only ssh allowed\n") expect(status).not_to be_success end + + it 'returns an error message when the API call fails with a message' do + _, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken_message"]) + + expect(stderr).to match(/Failed to get username: Forbidden!/) + expect(status).not_to be_success + end + + it 'returns an error message when the API call fails without a message' do + _, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken"]) + + expect(stderr).to match(/Failed to get username: Internal API error \(500\)/) + expect(status).not_to be_success + end end def run!(args, env: {'SSH_CONNECTION' => 'fake'}) |