summaryrefslogtreecommitdiff
path: root/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go
diff options
context:
space:
mode:
Diffstat (limited to 'go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go')
-rw-r--r--go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go55
1 files changed, 19 insertions, 36 deletions
diff --git a/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go b/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go
index 89f2a10..d0a51c0 100644
--- a/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go
+++ b/go/vendor/gitlab.com/gitlab-org/gitaly/client/dial.go
@@ -1,32 +1,36 @@
package client
import (
- "fmt"
- "net"
+ "google.golang.org/grpc/credentials"
+
"net/url"
- "strings"
- "time"
"google.golang.org/grpc"
)
// DefaultDialOpts hold the default DialOptions for connection to Gitaly over UNIX-socket
-var DefaultDialOpts = []grpc.DialOption{
- grpc.WithInsecure(),
-}
+var DefaultDialOpts = []grpc.DialOption{}
// Dial gitaly
func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, error) {
- network, addr, err := parseAddress(rawAddress)
+ canonicalAddress, err := parseAddress(rawAddress)
if err != nil {
return nil, err
}
- connOpts = append(connOpts,
- grpc.WithDialer(func(a string, timeout time.Duration) (net.Conn, error) {
- return net.DialTimeout(network, a, timeout)
- }))
- conn, err := grpc.Dial(addr, connOpts...)
+ if isTLS(rawAddress) {
+ certPool, err := systemCertPool()
+ if err != nil {
+ return nil, err
+ }
+
+ creds := credentials.NewClientTLSFromCert(certPool, "")
+ connOpts = append(connOpts, grpc.WithTransportCredentials(creds))
+ } else {
+ connOpts = append(connOpts, grpc.WithInsecure())
+ }
+
+ conn, err := grpc.Dial(canonicalAddress, connOpts...)
if err != nil {
return nil, err
}
@@ -34,28 +38,7 @@ func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, erro
return conn, nil
}
-func parseAddress(rawAddress string) (network, addr string, err error) {
- // Parsing unix:// URL's with url.Parse does not give the result we want
- // so we do it manually.
- for _, prefix := range []string{"unix://", "unix:"} {
- if strings.HasPrefix(rawAddress, prefix) {
- return "unix", strings.TrimPrefix(rawAddress, prefix), nil
- }
- }
-
+func isTLS(rawAddress string) bool {
u, err := url.Parse(rawAddress)
- if err != nil {
- return "", "", err
- }
-
- if u.Scheme != "tcp" {
- return "", "", fmt.Errorf("unknown scheme: %q", rawAddress)
- }
- if u.Host == "" {
- return "", "", fmt.Errorf("network tcp requires host: %q", rawAddress)
- }
- if u.Path != "" {
- return "", "", fmt.Errorf("network tcp should have no path: %q", rawAddress)
- }
- return "tcp", u.Host, nil
+ return err == nil && u.Scheme == "tls"
}