diff options
Diffstat (limited to 'go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go')
-rw-r--r-- | go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go | 55 |
1 files changed, 19 insertions, 36 deletions
diff --git a/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go b/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go index 89f2a10..d0a51c0 100644 --- a/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go +++ b/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go @@ -1,32 +1,36 @@ package client import ( - "fmt" - "net" + "google.golang.org/grpc/credentials" + "net/url" - "strings" - "time" "google.golang.org/grpc" ) // DefaultDialOpts hold the default DialOptions for connection to Gitaly over UNIX-socket -var DefaultDialOpts = []grpc.DialOption{ - grpc.WithInsecure(), -} +var DefaultDialOpts = []grpc.DialOption{} // Dial gitaly func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, error) { - network, addr, err := parseAddress(rawAddress) + canonicalAddress, err := parseAddress(rawAddress) if err != nil { return nil, err } - connOpts = append(connOpts, - grpc.WithDialer(func(a string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout(network, a, timeout) - })) - conn, err := grpc.Dial(addr, connOpts...) + if isTLS(rawAddress) { + certPool, err := systemCertPool() + if err != nil { + return nil, err + } + + creds := credentials.NewClientTLSFromCert(certPool, "") + connOpts = append(connOpts, grpc.WithTransportCredentials(creds)) + } else { + connOpts = append(connOpts, grpc.WithInsecure()) + } + + conn, err := grpc.Dial(canonicalAddress, connOpts...) if err != nil { return nil, err } @@ -34,28 +38,7 @@ func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, erro return conn, nil } -func parseAddress(rawAddress string) (network, addr string, err error) { - // Parsing unix:// URL's with url.Parse does not give the result we want - // so we do it manually. - for _, prefix := range []string{"unix://", "unix:"} { - if strings.HasPrefix(rawAddress, prefix) { - return "unix", strings.TrimPrefix(rawAddress, prefix), nil - } - } - +func isTLS(rawAddress string) bool { u, err := url.Parse(rawAddress) - if err != nil { - return "", "", err - } - - if u.Scheme != "tcp" { - return "", "", fmt.Errorf("unknown scheme: %q", rawAddress) - } - if u.Host == "" { - return "", "", fmt.Errorf("network tcp requires host: %q", rawAddress) - } - if u.Path != "" { - return "", "", fmt.Errorf("network tcp should have no path: %q", rawAddress) - } - return "tcp", u.Host, nil + return err == nil && u.Scheme == "tls" } |