summaryrefslogtreecommitdiff
path: root/lib/go/thrift/simple_server_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'lib/go/thrift/simple_server_test.go')
-rw-r--r--lib/go/thrift/simple_server_test.go135
1 files changed, 134 insertions, 1 deletions
diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go
index 58149a8e6..b92d50f01 100644
--- a/lib/go/thrift/simple_server_test.go
+++ b/lib/go/thrift/simple_server_test.go
@@ -20,11 +20,17 @@
package thrift
import (
- "testing"
+ "context"
"errors"
+ "net"
"runtime"
+ "sync"
+ "testing"
+ "time"
)
+const networkWaitDuration = 10 * time.Millisecond
+
type mockServerTransport struct {
ListenFunc func() error
AcceptFunc func() (TTransport, error)
@@ -154,3 +160,130 @@ func TestNoHangDuringStopFromDanglingLockAcquireDuringAcceptLoop(t *testing.T) {
runtime.Gosched()
serv.Stop()
}
+
+func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) {
+ ln, err := net.Listen("tcp", "localhost:0")
+
+ if err != nil {
+ t.Fatalf("Failed to listen: %v", err)
+ }
+
+ proc := &mockProcessor{
+ ProcessFunc: func(in, out TProtocol) (bool, TException) {
+ in.ReadMessageBegin(context.Background())
+ return false, nil
+ },
+ }
+
+ trans := &mockServerTransport{
+ ListenFunc: func() error {
+ return nil
+ },
+ AcceptFunc: func() (TTransport, error) {
+ conn, err := ln.Accept()
+ if err != nil {
+ return nil, err
+ }
+
+ return NewTSocketFromConnConf(conn, nil), nil
+ },
+ CloseFunc: func() error {
+ return nil
+ },
+ InterruptFunc: func() error {
+ return ln.Close()
+ },
+ }
+
+ serv := NewTSimpleServer2(proc, trans)
+ go serv.Serve()
+ time.Sleep(networkWaitDuration)
+
+ netConn, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil || netConn == nil {
+ t.Fatal("error when dial server")
+ }
+ time.Sleep(networkWaitDuration)
+
+ serverStopTimeout := 50 * time.Millisecond
+ backupServerStopTimeout := ServerStopTimeout
+ t.Cleanup(func() {
+ ServerStopTimeout = backupServerStopTimeout
+ })
+ ServerStopTimeout = serverStopTimeout
+
+ st := time.Now()
+ err = serv.Stop()
+ if err != nil {
+ t.Errorf("error when stop server:%v", err)
+ }
+
+ if elapsed := time.Since(st); elapsed < serverStopTimeout {
+ t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed)
+ }
+}
+
+func TestStopTimeoutWithSocketTimeout(t *testing.T) {
+ ln, err := net.Listen("tcp", "localhost:0")
+
+ if err != nil {
+ t.Fatalf("Failed to listen: %v", err)
+ }
+
+ proc := &mockProcessor{
+ ProcessFunc: func(in, out TProtocol) (bool, TException) {
+ in.ReadMessageBegin(context.Background())
+ return false, nil
+ },
+ }
+
+ conf := &TConfiguration{SocketTimeout: 5 * time.Millisecond}
+ wg := &sync.WaitGroup{}
+ trans := &mockServerTransport{
+ ListenFunc: func() error {
+ return nil
+ },
+ AcceptFunc: func() (TTransport, error) {
+ conn, err := ln.Accept()
+ if err != nil {
+ return nil, err
+ }
+ defer wg.Done()
+ return NewTSocketFromConnConf(conn, conf), nil
+ },
+ CloseFunc: func() error {
+ return nil
+ },
+ InterruptFunc: func() error {
+ return ln.Close()
+ },
+ }
+
+ serv := NewTSimpleServer2(proc, trans)
+ go serv.Serve()
+ time.Sleep(networkWaitDuration)
+
+ wg.Add(1)
+ netConn, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil || netConn == nil {
+ t.Fatal("error when dial server")
+ }
+ wg.Wait()
+
+ expectedStopTimeout := time.Second
+ backupServerStopTimeout := ServerStopTimeout
+ t.Cleanup(func() {
+ ServerStopTimeout = backupServerStopTimeout
+ })
+ ServerStopTimeout = expectedStopTimeout
+
+ st := time.Now()
+ err = serv.Stop()
+ if elapsed := time.Since(st); elapsed > expectedStopTimeout/2 {
+ t.Errorf("stop cost more time than socket timeout, socket timeout:%v,server stop timeout:%v,cost time:%v", conf.SocketTimeout, ServerStopTimeout, elapsed)
+ }
+
+ if err != nil {
+ t.Fatalf("error when stop server:%v", err)
+ }
+}