diff options
Diffstat (limited to 'lib/go/thrift/simple_server_test.go')
-rw-r--r-- | lib/go/thrift/simple_server_test.go | 135 |
1 files changed, 134 insertions, 1 deletions
diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go index 58149a8e6..b92d50f01 100644 --- a/lib/go/thrift/simple_server_test.go +++ b/lib/go/thrift/simple_server_test.go @@ -20,11 +20,17 @@ package thrift import ( - "testing" + "context" "errors" + "net" "runtime" + "sync" + "testing" + "time" ) +const networkWaitDuration = 10 * time.Millisecond + type mockServerTransport struct { ListenFunc func() error AcceptFunc func() (TTransport, error) @@ -154,3 +160,130 @@ func TestNoHangDuringStopFromDanglingLockAcquireDuringAcceptLoop(t *testing.T) { runtime.Gosched() serv.Stop() } + +func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + + proc := &mockProcessor{ + ProcessFunc: func(in, out TProtocol) (bool, TException) { + in.ReadMessageBegin(context.Background()) + return false, nil + }, + } + + trans := &mockServerTransport{ + ListenFunc: func() error { + return nil + }, + AcceptFunc: func() (TTransport, error) { + conn, err := ln.Accept() + if err != nil { + return nil, err + } + + return NewTSocketFromConnConf(conn, nil), nil + }, + CloseFunc: func() error { + return nil + }, + InterruptFunc: func() error { + return ln.Close() + }, + } + + serv := NewTSimpleServer2(proc, trans) + go serv.Serve() + time.Sleep(networkWaitDuration) + + netConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil || netConn == nil { + t.Fatal("error when dial server") + } + time.Sleep(networkWaitDuration) + + serverStopTimeout := 50 * time.Millisecond + backupServerStopTimeout := ServerStopTimeout + t.Cleanup(func() { + ServerStopTimeout = backupServerStopTimeout + }) + ServerStopTimeout = serverStopTimeout + + st := time.Now() + err = serv.Stop() + if err != nil { + t.Errorf("error when stop server:%v", err) + } + + if elapsed := time.Since(st); elapsed < serverStopTimeout { + t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed) + } +} + +func TestStopTimeoutWithSocketTimeout(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + + proc := &mockProcessor{ + ProcessFunc: func(in, out TProtocol) (bool, TException) { + in.ReadMessageBegin(context.Background()) + return false, nil + }, + } + + conf := &TConfiguration{SocketTimeout: 5 * time.Millisecond} + wg := &sync.WaitGroup{} + trans := &mockServerTransport{ + ListenFunc: func() error { + return nil + }, + AcceptFunc: func() (TTransport, error) { + conn, err := ln.Accept() + if err != nil { + return nil, err + } + defer wg.Done() + return NewTSocketFromConnConf(conn, conf), nil + }, + CloseFunc: func() error { + return nil + }, + InterruptFunc: func() error { + return ln.Close() + }, + } + + serv := NewTSimpleServer2(proc, trans) + go serv.Serve() + time.Sleep(networkWaitDuration) + + wg.Add(1) + netConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil || netConn == nil { + t.Fatal("error when dial server") + } + wg.Wait() + + expectedStopTimeout := time.Second + backupServerStopTimeout := ServerStopTimeout + t.Cleanup(func() { + ServerStopTimeout = backupServerStopTimeout + }) + ServerStopTimeout = expectedStopTimeout + + st := time.Now() + err = serv.Stop() + if elapsed := time.Since(st); elapsed > expectedStopTimeout/2 { + t.Errorf("stop cost more time than socket timeout, socket timeout:%v,server stop timeout:%v,cost time:%v", conf.SocketTimeout, ServerStopTimeout, elapsed) + } + + if err != nil { + t.Fatalf("error when stop server:%v", err) + } +} |