summaryrefslogtreecommitdiff
path: root/lib/go/thrift
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2020-05-26 15:31:20 -0700
committerGitHub <noreply@github.com>2020-05-26 23:31:20 +0100
commit05023e81b264f249affdacad4ebae788b3ada85c (patch)
tree4fba9d044a814649779bed8c61b87cbb871b6080 /lib/go/thrift
parentd28f39fbc7bb9607a150544dd8f73f027c898c9b (diff)
downloadthrift-05023e81b264f249affdacad4ebae788b3ada85c.tar.gz
THRIFT-5214: Connectivity check on go's TSocket
Client: go Implement connectivity check on go's TSocket and TSSLSocket for non-Windows systems. The implementation is inspired by https://github.blog/2020-05-20-three-bugs-in-the-go-mysql-driver/
Diffstat (limited to 'lib/go/thrift')
-rw-r--r--lib/go/thrift/socket.go23
-rw-r--r--lib/go/thrift/socket_conn.go111
-rw-r--r--lib/go/thrift/socket_conn_test.go125
-rw-r--r--lib/go/thrift/socket_unix_conn.go73
-rw-r--r--lib/go/thrift/socket_unix_conn_test.go105
-rw-r--r--lib/go/thrift/socket_windows_conn.go34
-rw-r--r--lib/go/thrift/ssl_socket.go37
7 files changed, 483 insertions, 25 deletions
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index 558818a9a..7c765f56c 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -26,7 +26,7 @@ import (
)
type TSocket struct {
- conn net.Conn
+ conn *socketConn
addr net.Addr
connectTimeout time.Duration
socketTimeout time.Duration
@@ -58,7 +58,7 @@ func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, soTimeo
// Creates a TSocket from an existing net.Conn
func NewTSocketFromConnTimeout(conn net.Conn, connTimeout time.Duration) *TSocket {
- return &TSocket{conn: conn, addr: conn.RemoteAddr(), connectTimeout: connTimeout, socketTimeout: connTimeout}
+ return &TSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), connectTimeout: connTimeout, socketTimeout: connTimeout}
}
// Sets the connect timeout
@@ -89,7 +89,7 @@ func (p *TSocket) pushDeadline(read, write bool) {
// Connects the socket, creating a new socket object if necessary.
func (p *TSocket) Open() error {
- if p.IsOpen() {
+ if p.conn.isValid() {
return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
}
if p.addr == nil {
@@ -102,7 +102,11 @@ func (p *TSocket) Open() error {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
var err error
- if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.connectTimeout); err != nil {
+ if p.conn, err = createSocketConnFromReturn(net.DialTimeout(
+ p.addr.Network(),
+ p.addr.String(),
+ p.connectTimeout,
+ )); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
return nil
@@ -115,10 +119,7 @@ func (p *TSocket) Conn() net.Conn {
// Returns true if the connection is open
func (p *TSocket) IsOpen() bool {
- if p.conn == nil {
- return false
- }
- return true
+ return p.conn.IsOpen()
}
// Closes the socket.
@@ -140,7 +141,7 @@ func (p *TSocket) Addr() net.Addr {
}
func (p *TSocket) Read(buf []byte) (int, error) {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(true, false)
@@ -149,7 +150,7 @@ func (p *TSocket) Read(buf []byte) (int, error) {
}
func (p *TSocket) Write(buf []byte) (int, error) {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(false, true)
@@ -161,7 +162,7 @@ func (p *TSocket) Flush(ctx context.Context) error {
}
func (p *TSocket) Interrupt() error {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return nil
}
return p.conn.Close()
diff --git a/lib/go/thrift/socket_conn.go b/lib/go/thrift/socket_conn.go
new file mode 100644
index 000000000..b0f7b3e69
--- /dev/null
+++ b/lib/go/thrift/socket_conn.go
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "bytes"
+ "io"
+ "net"
+)
+
+// socketConn is a wrapped net.Conn that tries to do connectivity check.
+type socketConn struct {
+ net.Conn
+
+ buf bytes.Buffer
+}
+
+var _ net.Conn = (*socketConn)(nil)
+
+// createSocketConnFromReturn is a language sugar to help create socketConn from
+// return values of functions like net.Dial, tls.Dial, net.Listener.Accept, etc.
+func createSocketConnFromReturn(conn net.Conn, err error) (*socketConn, error) {
+ if err != nil {
+ return nil, err
+ }
+ return &socketConn{
+ Conn: conn,
+ }, nil
+}
+
+// wrapSocketConn wraps an existing net.Conn into *socketConn.
+func wrapSocketConn(conn net.Conn) *socketConn {
+ // In case conn is already wrapped,
+ // return it as-is and avoid double wrapping.
+ if sc, ok := conn.(*socketConn); ok {
+ return sc
+ }
+
+ return &socketConn{
+ Conn: conn,
+ }
+}
+
+// isValid checks whether there's a valid connection.
+//
+// It's nil safe, and returns false if sc itself is nil, or if the underlying
+// connection is nil.
+//
+// It's the same as the previous implementation of TSocket.IsOpen and
+// TSSLSocket.IsOpen before we added connectivity check.
+func (sc *socketConn) isValid() bool {
+ return sc != nil && sc.Conn != nil
+}
+
+// IsOpen checks whether the connection is open.
+//
+// It's nil safe, and returns false if sc itself is nil, or if the underlying
+// connection is nil.
+//
+// Otherwise, it tries to do a connectivity check and returns the result.
+func (sc *socketConn) IsOpen() bool {
+ if !sc.isValid() {
+ return false
+ }
+ return sc.checkConn() == nil
+}
+
+// Read implements io.Reader.
+//
+// On Windows, it behaves the same as the underlying net.Conn.Read.
+//
+// On non-Windows, it treats len(p) == 0 as a connectivity check instead of
+// readability check, which means instead of blocking until there's something to
+// read (readability check), or always return (0, nil) (the default behavior of
+// go's stdlib implementation on non-Windows), it never blocks, and will return
+// an error if the connection is lost.
+func (sc *socketConn) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return 0, sc.read0()
+ }
+
+ n, err = sc.buf.Read(p)
+ if err != nil && err != io.EOF {
+ return
+ }
+ if n == len(p) {
+ return n, nil
+ }
+ // Continue reading from the wire.
+ var newRead int
+ newRead, err = sc.Conn.Read(p[n:])
+ n += newRead
+ return
+}
diff --git a/lib/go/thrift/socket_conn_test.go b/lib/go/thrift/socket_conn_test.go
new file mode 100644
index 000000000..ab924620c
--- /dev/null
+++ b/lib/go/thrift/socket_conn_test.go
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "io"
+ "net"
+ "strings"
+ "testing"
+ "time"
+)
+
+type serverSocketConnCallback func(testing.TB, *socketConn)
+
+func serverSocketConn(tb testing.TB, f serverSocketConnCallback) (net.Listener, error) {
+ tb.Helper()
+
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return nil, err
+ }
+ go func() {
+ for {
+ sc, err := createSocketConnFromReturn(ln.Accept())
+ if err != nil {
+ // This is usually caused by Listener being
+ // closed, not really an error.
+ return
+ }
+ go f(tb, sc)
+ }
+ }()
+ return ln, nil
+}
+
+func writeFully(tb testing.TB, w io.Writer, s string) bool {
+ tb.Helper()
+
+ n, err := io.Copy(w, strings.NewReader(s))
+ if err != nil {
+ tb.Errorf("Failed to write %q: %v", s, err)
+ return false
+ }
+ if int(n) < len(s) {
+ tb.Errorf("Only wrote %d out of %q", n, s)
+ return false
+ }
+ return true
+}
+
+func TestSocketConn(t *testing.T) {
+ const (
+ interval = time.Millisecond * 10
+ first = "hello"
+ second = "world"
+ )
+
+ ln, err := serverSocketConn(
+ t,
+ func(tb testing.TB, sc *socketConn) {
+ defer sc.Close()
+
+ if !writeFully(tb, sc, first) {
+ return
+ }
+ time.Sleep(interval)
+ writeFully(tb, sc, second)
+ },
+ )
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String()))
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, 1024)
+
+ n, err := sc.Read(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ read := string(buf[:n])
+ if read != first {
+ t.Errorf("Expected read %q, got %q", first, read)
+ }
+
+ n, err = sc.Read(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ read = string(buf[:n])
+ if read != second {
+ t.Errorf("Expected read %q, got %q", second, read)
+ }
+}
+
+func TestSocketConnNilSafe(t *testing.T) {
+ sc := (*socketConn)(nil)
+ if sc.isValid() {
+ t.Error("Expected false for nil.isValid(), got true")
+ }
+ if sc.IsOpen() {
+ t.Error("Expected false for nil.IsOpen(), got true")
+ }
+}
diff --git a/lib/go/thrift/socket_unix_conn.go b/lib/go/thrift/socket_unix_conn.go
new file mode 100644
index 000000000..f18e0e670
--- /dev/null
+++ b/lib/go/thrift/socket_unix_conn.go
@@ -0,0 +1,73 @@
+// +build !windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "io"
+ "syscall"
+)
+
+func (sc *socketConn) read0() error {
+ return sc.checkConn()
+}
+
+func (sc *socketConn) checkConn() error {
+ syscallConn, ok := sc.Conn.(syscall.Conn)
+ if !ok {
+ // No way to check, return nil
+ return nil
+ }
+ rc, err := syscallConn.SyscallConn()
+ if err != nil {
+ return err
+ }
+
+ var n int
+ var buf [1]byte
+
+ if readErr := rc.Read(func(fd uintptr) bool {
+ n, err = syscall.Read(int(fd), buf[:])
+ return true
+ }); readErr != nil {
+ return readErr
+ }
+
+ if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
+ // This means the connection is still open but we don't have
+ // anything to read right now.
+ return nil
+ }
+
+ if n > 0 {
+ // We got 1 byte,
+ // put it to sc's buf for the next real read to use.
+ sc.buf.Write(buf[:])
+ return nil
+ }
+
+ if err != nil {
+ return err
+ }
+
+ // At this point, it means the other side already closed the connection.
+ return io.EOF
+}
diff --git a/lib/go/thrift/socket_unix_conn_test.go b/lib/go/thrift/socket_unix_conn_test.go
new file mode 100644
index 000000000..3563a259c
--- /dev/null
+++ b/lib/go/thrift/socket_unix_conn_test.go
@@ -0,0 +1,105 @@
+// +build !windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "io"
+ "net"
+ "testing"
+ "time"
+)
+
+func TestSocketConnUnix(t *testing.T) {
+ const (
+ interval = time.Millisecond * 10
+ first = "hello"
+ second = "world"
+ )
+
+ ln, err := serverSocketConn(
+ t,
+ func(tb testing.TB, sc *socketConn) {
+ defer sc.Close()
+
+ time.Sleep(interval)
+ if !writeFully(tb, sc, first) {
+ return
+ }
+ time.Sleep(interval)
+ writeFully(tb, sc, second)
+ },
+ )
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String()))
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, 1024)
+
+ if !sc.IsOpen() {
+ t.Error("Expected sc to report open, got false")
+ }
+ n, err := sc.Read(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ read := string(buf[:n])
+ if read != first {
+ t.Errorf("Expected read %q, got %q", first, read)
+ }
+
+ if !sc.IsOpen() {
+ t.Error("Expected sc to report open, got false")
+ }
+ // Do connection check again twice after server already wrote new data,
+ // make sure we correctly buffered the read bytes
+ time.Sleep(interval * 10)
+ if !sc.IsOpen() {
+ t.Error("Expected sc to report open, got false")
+ }
+ if !sc.IsOpen() {
+ t.Error("Expected sc to report open, got false")
+ }
+ if sc.buf.Len() == 0 {
+ t.Error("Expected sc to buffer read bytes, got empty buffer")
+ }
+ n, err = sc.Read(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ read = string(buf[:n])
+ if read != second {
+ t.Errorf("Expected read %q, got %q", second, read)
+ }
+
+ // Now it's supposed to be closed on the server side
+ if err := sc.read0(); err != io.EOF {
+ t.Errorf("Expected to get EOF on read0, got %v", err)
+ }
+ if sc.IsOpen() {
+ t.Error("Expected sc to report not open, got true")
+ }
+}
diff --git a/lib/go/thrift/socket_windows_conn.go b/lib/go/thrift/socket_windows_conn.go
new file mode 100644
index 000000000..679838c3b
--- /dev/null
+++ b/lib/go/thrift/socket_windows_conn.go
@@ -0,0 +1,34 @@
+// +build windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+func (sc *socketConn) read0() error {
+ // On windows, we fallback to the default behavior of reading 0 bytes.
+ var p []byte
+ _, err := sc.Conn.Read(p)
+ return err
+}
+
+func (sc *socketConn) checkConn() error {
+ // On windows, we always return nil for this check.
+ return nil
+}
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index 45bf38a28..661111cdd 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -27,7 +27,7 @@ import (
)
type TSSLSocket struct {
- conn net.Conn
+ conn *socketConn
// hostPort contains host:port (e.g. "asdf.com:12345"). The field is
// only valid if addr is nil.
hostPort string
@@ -62,7 +62,7 @@ func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.D
// Creates a TSSLSocket from an existing net.Conn
func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
- return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
+ return &TSSLSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
}
// Sets the socket timeout
@@ -91,12 +91,18 @@ func (p *TSSLSocket) Open() error {
// If we have a hostname, we need to pass the hostname to tls.Dial for
// certificate hostname checks.
if p.hostPort != "" {
- if p.conn, err = tls.DialWithDialer(&net.Dialer{
- Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil {
+ if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
+ &net.Dialer{
+ Timeout: p.timeout,
+ },
+ "tcp",
+ p.hostPort,
+ p.cfg,
+ )); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
} else {
- if p.IsOpen() {
+ if p.conn.isValid() {
return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
}
if p.addr == nil {
@@ -108,8 +114,14 @@ func (p *TSSLSocket) Open() error {
if len(p.addr.String()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
- if p.conn, err = tls.DialWithDialer(&net.Dialer{
- Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil {
+ if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
+ &net.Dialer{
+ Timeout: p.timeout,
+ },
+ p.addr.Network(),
+ p.addr.String(),
+ p.cfg,
+ )); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
}
@@ -123,10 +135,7 @@ func (p *TSSLSocket) Conn() net.Conn {
// Returns true if the connection is open
func (p *TSSLSocket) IsOpen() bool {
- if p.conn == nil {
- return false
- }
- return true
+ return p.conn.IsOpen()
}
// Closes the socket.
@@ -143,7 +152,7 @@ func (p *TSSLSocket) Close() error {
}
func (p *TSSLSocket) Read(buf []byte) (int, error) {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(true, false)
@@ -152,7 +161,7 @@ func (p *TSSLSocket) Read(buf []byte) (int, error) {
}
func (p *TSSLSocket) Write(buf []byte) (int, error) {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(false, true)
@@ -164,7 +173,7 @@ func (p *TSSLSocket) Flush(ctx context.Context) error {
}
func (p *TSSLSocket) Interrupt() error {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return nil
}
return p.conn.Close()