summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/go/thrift/socket.go10
-rw-r--r--lib/go/thrift/socket_conn.go13
-rw-r--r--lib/go/thrift/ssl_socket.go10
3 files changed, 14 insertions, 19 deletions
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index eeac4f1a4..cba7c0f77 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -194,15 +194,7 @@ func (p *TSocket) IsOpen() bool {
// Closes the socket.
func (p *TSocket) Close() error {
- // Close the socket
- if p.conn != nil {
- err := p.conn.Close()
- if err != nil {
- return err
- }
- p.conn = nil
- }
- return nil
+ return p.conn.Close()
}
//Returns the remote address of the socket.
diff --git a/lib/go/thrift/socket_conn.go b/lib/go/thrift/socket_conn.go
index c1cc30c6c..5619d9626 100644
--- a/lib/go/thrift/socket_conn.go
+++ b/lib/go/thrift/socket_conn.go
@@ -21,6 +21,7 @@ package thrift
import (
"net"
+ "sync/atomic"
)
// socketConn is a wrapped net.Conn that tries to do connectivity check.
@@ -28,6 +29,7 @@ type socketConn struct {
net.Conn
buffer [1]byte
+ closed int32
}
var _ net.Conn = (*socketConn)(nil)
@@ -64,7 +66,7 @@ func wrapSocketConn(conn net.Conn) *socketConn {
// It's the same as the previous implementation of TSocket.IsOpen and
// TSSLSocket.IsOpen before we added connectivity check.
func (sc *socketConn) isValid() bool {
- return sc != nil && sc.Conn != nil
+ return sc != nil && sc.Conn != nil && atomic.LoadInt32(&sc.closed) == 0
}
// IsOpen checks whether the connection is open.
@@ -100,3 +102,12 @@ func (sc *socketConn) Read(p []byte) (n int, err error) {
return sc.Conn.Read(p)
}
+
+func (sc *socketConn) Close() error {
+ if !sc.isValid() {
+ // Already closed
+ return net.ErrClosed
+ }
+ atomic.StoreInt32(&sc.closed, 1)
+ return sc.Conn.Close()
+}
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index bee1097d4..d7ba415ec 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -220,15 +220,7 @@ func (p *TSSLSocket) IsOpen() bool {
// Closes the socket.
func (p *TSSLSocket) Close() error {
- // Close the socket
- if p.conn != nil {
- err := p.conn.Close()
- if err != nil {
- return err
- }
- p.conn = nil
- }
- return nil
+ return p.conn.Close()
}
func (p *TSSLSocket) Read(buf []byte) (int, error) {