diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2020-05-26 15:31:20 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-26 23:31:20 +0100 |
commit | 05023e81b264f249affdacad4ebae788b3ada85c (patch) | |
tree | 4fba9d044a814649779bed8c61b87cbb871b6080 /lib/go/thrift | |
parent | d28f39fbc7bb9607a150544dd8f73f027c898c9b (diff) | |
download | thrift-05023e81b264f249affdacad4ebae788b3ada85c.tar.gz |
THRIFT-5214: Connectivity check on go's TSocket
Client: go
Implement connectivity check on go's TSocket and TSSLSocket for
non-Windows systems.
The implementation is inspired by
https://github.blog/2020-05-20-three-bugs-in-the-go-mysql-driver/
Diffstat (limited to 'lib/go/thrift')
-rw-r--r-- | lib/go/thrift/socket.go | 23 | ||||
-rw-r--r-- | lib/go/thrift/socket_conn.go | 111 | ||||
-rw-r--r-- | lib/go/thrift/socket_conn_test.go | 125 | ||||
-rw-r--r-- | lib/go/thrift/socket_unix_conn.go | 73 | ||||
-rw-r--r-- | lib/go/thrift/socket_unix_conn_test.go | 105 | ||||
-rw-r--r-- | lib/go/thrift/socket_windows_conn.go | 34 | ||||
-rw-r--r-- | lib/go/thrift/ssl_socket.go | 37 |
7 files changed, 483 insertions, 25 deletions
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go index 558818a9a..7c765f56c 100644 --- a/lib/go/thrift/socket.go +++ b/lib/go/thrift/socket.go @@ -26,7 +26,7 @@ import ( ) type TSocket struct { - conn net.Conn + conn *socketConn addr net.Addr connectTimeout time.Duration socketTimeout time.Duration @@ -58,7 +58,7 @@ func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, soTimeo // Creates a TSocket from an existing net.Conn func NewTSocketFromConnTimeout(conn net.Conn, connTimeout time.Duration) *TSocket { - return &TSocket{conn: conn, addr: conn.RemoteAddr(), connectTimeout: connTimeout, socketTimeout: connTimeout} + return &TSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), connectTimeout: connTimeout, socketTimeout: connTimeout} } // Sets the connect timeout @@ -89,7 +89,7 @@ func (p *TSocket) pushDeadline(read, write bool) { // Connects the socket, creating a new socket object if necessary. func (p *TSocket) Open() error { - if p.IsOpen() { + if p.conn.isValid() { return NewTTransportException(ALREADY_OPEN, "Socket already connected.") } if p.addr == nil { @@ -102,7 +102,11 @@ func (p *TSocket) Open() error { return NewTTransportException(NOT_OPEN, "Cannot open bad address.") } var err error - if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.connectTimeout); err != nil { + if p.conn, err = createSocketConnFromReturn(net.DialTimeout( + p.addr.Network(), + p.addr.String(), + p.connectTimeout, + )); err != nil { return NewTTransportException(NOT_OPEN, err.Error()) } return nil @@ -115,10 +119,7 @@ func (p *TSocket) Conn() net.Conn { // Returns true if the connection is open func (p *TSocket) IsOpen() bool { - if p.conn == nil { - return false - } - return true + return p.conn.IsOpen() } // Closes the socket. @@ -140,7 +141,7 @@ func (p *TSocket) Addr() net.Addr { } func (p *TSocket) Read(buf []byte) (int, error) { - if !p.IsOpen() { + if !p.conn.isValid() { return 0, NewTTransportException(NOT_OPEN, "Connection not open") } p.pushDeadline(true, false) @@ -149,7 +150,7 @@ func (p *TSocket) Read(buf []byte) (int, error) { } func (p *TSocket) Write(buf []byte) (int, error) { - if !p.IsOpen() { + if !p.conn.isValid() { return 0, NewTTransportException(NOT_OPEN, "Connection not open") } p.pushDeadline(false, true) @@ -161,7 +162,7 @@ func (p *TSocket) Flush(ctx context.Context) error { } func (p *TSocket) Interrupt() error { - if !p.IsOpen() { + if !p.conn.isValid() { return nil } return p.conn.Close() diff --git a/lib/go/thrift/socket_conn.go b/lib/go/thrift/socket_conn.go new file mode 100644 index 000000000..b0f7b3e69 --- /dev/null +++ b/lib/go/thrift/socket_conn.go @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "io" + "net" +) + +// socketConn is a wrapped net.Conn that tries to do connectivity check. +type socketConn struct { + net.Conn + + buf bytes.Buffer +} + +var _ net.Conn = (*socketConn)(nil) + +// createSocketConnFromReturn is a language sugar to help create socketConn from +// return values of functions like net.Dial, tls.Dial, net.Listener.Accept, etc. +func createSocketConnFromReturn(conn net.Conn, err error) (*socketConn, error) { + if err != nil { + return nil, err + } + return &socketConn{ + Conn: conn, + }, nil +} + +// wrapSocketConn wraps an existing net.Conn into *socketConn. +func wrapSocketConn(conn net.Conn) *socketConn { + // In case conn is already wrapped, + // return it as-is and avoid double wrapping. + if sc, ok := conn.(*socketConn); ok { + return sc + } + + return &socketConn{ + Conn: conn, + } +} + +// isValid checks whether there's a valid connection. +// +// It's nil safe, and returns false if sc itself is nil, or if the underlying +// connection is nil. +// +// It's the same as the previous implementation of TSocket.IsOpen and +// TSSLSocket.IsOpen before we added connectivity check. +func (sc *socketConn) isValid() bool { + return sc != nil && sc.Conn != nil +} + +// IsOpen checks whether the connection is open. +// +// It's nil safe, and returns false if sc itself is nil, or if the underlying +// connection is nil. +// +// Otherwise, it tries to do a connectivity check and returns the result. +func (sc *socketConn) IsOpen() bool { + if !sc.isValid() { + return false + } + return sc.checkConn() == nil +} + +// Read implements io.Reader. +// +// On Windows, it behaves the same as the underlying net.Conn.Read. +// +// On non-Windows, it treats len(p) == 0 as a connectivity check instead of +// readability check, which means instead of blocking until there's something to +// read (readability check), or always return (0, nil) (the default behavior of +// go's stdlib implementation on non-Windows), it never blocks, and will return +// an error if the connection is lost. +func (sc *socketConn) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, sc.read0() + } + + n, err = sc.buf.Read(p) + if err != nil && err != io.EOF { + return + } + if n == len(p) { + return n, nil + } + // Continue reading from the wire. + var newRead int + newRead, err = sc.Conn.Read(p[n:]) + n += newRead + return +} diff --git a/lib/go/thrift/socket_conn_test.go b/lib/go/thrift/socket_conn_test.go new file mode 100644 index 000000000..ab924620c --- /dev/null +++ b/lib/go/thrift/socket_conn_test.go @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "io" + "net" + "strings" + "testing" + "time" +) + +type serverSocketConnCallback func(testing.TB, *socketConn) + +func serverSocketConn(tb testing.TB, f serverSocketConnCallback) (net.Listener, error) { + tb.Helper() + + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + return nil, err + } + go func() { + for { + sc, err := createSocketConnFromReturn(ln.Accept()) + if err != nil { + // This is usually caused by Listener being + // closed, not really an error. + return + } + go f(tb, sc) + } + }() + return ln, nil +} + +func writeFully(tb testing.TB, w io.Writer, s string) bool { + tb.Helper() + + n, err := io.Copy(w, strings.NewReader(s)) + if err != nil { + tb.Errorf("Failed to write %q: %v", s, err) + return false + } + if int(n) < len(s) { + tb.Errorf("Only wrote %d out of %q", n, s) + return false + } + return true +} + +func TestSocketConn(t *testing.T) { + const ( + interval = time.Millisecond * 10 + first = "hello" + second = "world" + ) + + ln, err := serverSocketConn( + t, + func(tb testing.TB, sc *socketConn) { + defer sc.Close() + + if !writeFully(tb, sc, first) { + return + } + time.Sleep(interval) + writeFully(tb, sc, second) + }, + ) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String())) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + + n, err := sc.Read(buf) + if err != nil { + t.Fatal(err) + } + read := string(buf[:n]) + if read != first { + t.Errorf("Expected read %q, got %q", first, read) + } + + n, err = sc.Read(buf) + if err != nil { + t.Fatal(err) + } + read = string(buf[:n]) + if read != second { + t.Errorf("Expected read %q, got %q", second, read) + } +} + +func TestSocketConnNilSafe(t *testing.T) { + sc := (*socketConn)(nil) + if sc.isValid() { + t.Error("Expected false for nil.isValid(), got true") + } + if sc.IsOpen() { + t.Error("Expected false for nil.IsOpen(), got true") + } +} diff --git a/lib/go/thrift/socket_unix_conn.go b/lib/go/thrift/socket_unix_conn.go new file mode 100644 index 000000000..f18e0e670 --- /dev/null +++ b/lib/go/thrift/socket_unix_conn.go @@ -0,0 +1,73 @@ +// +build !windows + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "io" + "syscall" +) + +func (sc *socketConn) read0() error { + return sc.checkConn() +} + +func (sc *socketConn) checkConn() error { + syscallConn, ok := sc.Conn.(syscall.Conn) + if !ok { + // No way to check, return nil + return nil + } + rc, err := syscallConn.SyscallConn() + if err != nil { + return err + } + + var n int + var buf [1]byte + + if readErr := rc.Read(func(fd uintptr) bool { + n, err = syscall.Read(int(fd), buf[:]) + return true + }); readErr != nil { + return readErr + } + + if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK { + // This means the connection is still open but we don't have + // anything to read right now. + return nil + } + + if n > 0 { + // We got 1 byte, + // put it to sc's buf for the next real read to use. + sc.buf.Write(buf[:]) + return nil + } + + if err != nil { + return err + } + + // At this point, it means the other side already closed the connection. + return io.EOF +} diff --git a/lib/go/thrift/socket_unix_conn_test.go b/lib/go/thrift/socket_unix_conn_test.go new file mode 100644 index 000000000..3563a259c --- /dev/null +++ b/lib/go/thrift/socket_unix_conn_test.go @@ -0,0 +1,105 @@ +// +build !windows + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "io" + "net" + "testing" + "time" +) + +func TestSocketConnUnix(t *testing.T) { + const ( + interval = time.Millisecond * 10 + first = "hello" + second = "world" + ) + + ln, err := serverSocketConn( + t, + func(tb testing.TB, sc *socketConn) { + defer sc.Close() + + time.Sleep(interval) + if !writeFully(tb, sc, first) { + return + } + time.Sleep(interval) + writeFully(tb, sc, second) + }, + ) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String())) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + + if !sc.IsOpen() { + t.Error("Expected sc to report open, got false") + } + n, err := sc.Read(buf) + if err != nil { + t.Fatal(err) + } + read := string(buf[:n]) + if read != first { + t.Errorf("Expected read %q, got %q", first, read) + } + + if !sc.IsOpen() { + t.Error("Expected sc to report open, got false") + } + // Do connection check again twice after server already wrote new data, + // make sure we correctly buffered the read bytes + time.Sleep(interval * 10) + if !sc.IsOpen() { + t.Error("Expected sc to report open, got false") + } + if !sc.IsOpen() { + t.Error("Expected sc to report open, got false") + } + if sc.buf.Len() == 0 { + t.Error("Expected sc to buffer read bytes, got empty buffer") + } + n, err = sc.Read(buf) + if err != nil { + t.Fatal(err) + } + read = string(buf[:n]) + if read != second { + t.Errorf("Expected read %q, got %q", second, read) + } + + // Now it's supposed to be closed on the server side + if err := sc.read0(); err != io.EOF { + t.Errorf("Expected to get EOF on read0, got %v", err) + } + if sc.IsOpen() { + t.Error("Expected sc to report not open, got true") + } +} diff --git a/lib/go/thrift/socket_windows_conn.go b/lib/go/thrift/socket_windows_conn.go new file mode 100644 index 000000000..679838c3b --- /dev/null +++ b/lib/go/thrift/socket_windows_conn.go @@ -0,0 +1,34 @@ +// +build windows + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +func (sc *socketConn) read0() error { + // On windows, we fallback to the default behavior of reading 0 bytes. + var p []byte + _, err := sc.Conn.Read(p) + return err +} + +func (sc *socketConn) checkConn() error { + // On windows, we always return nil for this check. + return nil +} diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go index 45bf38a28..661111cdd 100644 --- a/lib/go/thrift/ssl_socket.go +++ b/lib/go/thrift/ssl_socket.go @@ -27,7 +27,7 @@ import ( ) type TSSLSocket struct { - conn net.Conn + conn *socketConn // hostPort contains host:port (e.g. "asdf.com:12345"). The field is // only valid if addr is nil. hostPort string @@ -62,7 +62,7 @@ func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.D // Creates a TSSLSocket from an existing net.Conn func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket { - return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg} + return &TSSLSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg} } // Sets the socket timeout @@ -91,12 +91,18 @@ func (p *TSSLSocket) Open() error { // If we have a hostname, we need to pass the hostname to tls.Dial for // certificate hostname checks. if p.hostPort != "" { - if p.conn, err = tls.DialWithDialer(&net.Dialer{ - Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil { + if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer( + &net.Dialer{ + Timeout: p.timeout, + }, + "tcp", + p.hostPort, + p.cfg, + )); err != nil { return NewTTransportException(NOT_OPEN, err.Error()) } } else { - if p.IsOpen() { + if p.conn.isValid() { return NewTTransportException(ALREADY_OPEN, "Socket already connected.") } if p.addr == nil { @@ -108,8 +114,14 @@ func (p *TSSLSocket) Open() error { if len(p.addr.String()) == 0 { return NewTTransportException(NOT_OPEN, "Cannot open bad address.") } - if p.conn, err = tls.DialWithDialer(&net.Dialer{ - Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil { + if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer( + &net.Dialer{ + Timeout: p.timeout, + }, + p.addr.Network(), + p.addr.String(), + p.cfg, + )); err != nil { return NewTTransportException(NOT_OPEN, err.Error()) } } @@ -123,10 +135,7 @@ func (p *TSSLSocket) Conn() net.Conn { // Returns true if the connection is open func (p *TSSLSocket) IsOpen() bool { - if p.conn == nil { - return false - } - return true + return p.conn.IsOpen() } // Closes the socket. @@ -143,7 +152,7 @@ func (p *TSSLSocket) Close() error { } func (p *TSSLSocket) Read(buf []byte) (int, error) { - if !p.IsOpen() { + if !p.conn.isValid() { return 0, NewTTransportException(NOT_OPEN, "Connection not open") } p.pushDeadline(true, false) @@ -152,7 +161,7 @@ func (p *TSSLSocket) Read(buf []byte) (int, error) { } func (p *TSSLSocket) Write(buf []byte) (int, error) { - if !p.IsOpen() { + if !p.conn.isValid() { return 0, NewTTransportException(NOT_OPEN, "Connection not open") } p.pushDeadline(false, true) @@ -164,7 +173,7 @@ func (p *TSSLSocket) Flush(ctx context.Context) error { } func (p *TSSLSocket) Interrupt() error { - if !p.IsOpen() { + if !p.conn.isValid() { return nil } return p.conn.Close() |