diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2022-08-08 22:12:40 -0700 |
---|---|---|
committer | Yuxuan 'fishy' Wang <fishywang@gmail.com> | 2022-08-09 17:41:45 -0700 |
commit | bdfde857a802e443a2cab1717744dee8e56cbe76 (patch) | |
tree | 71f2686e775c35667f3fdae4eb83b814ecb77ed1 | |
parent | 7ae180bb1eaea8bdfd6d5714aa90b8445165ff1c (diff) | |
download | thrift-bdfde857a802e443a2cab1717744dee8e56cbe76.tar.gz |
Add a generic sync.Pool wrapper to go library
Since we dropped support of Go 1.18-, use generic to avoid dealing with
type assertions with interface{}/any.
While I'm here, also remove the usages of ioutil, as that's officially
marked as deprecated in Go 1.19.
Client: go
-rw-r--r-- | lib/go/thrift/deserializer.go | 37 | ||||
-rw-r--r-- | lib/go/thrift/framed_transport.go | 10 | ||||
-rw-r--r-- | lib/go/thrift/header_transport.go | 12 | ||||
-rw-r--r-- | lib/go/thrift/http_client.go | 3 | ||||
-rw-r--r-- | lib/go/thrift/http_transport.go | 15 | ||||
-rw-r--r-- | lib/go/thrift/pool.go | 69 | ||||
-rw-r--r-- | lib/go/thrift/pool_test.go (renamed from lib/go/thrift/buf_pool.go) | 51 | ||||
-rw-r--r-- | lib/go/thrift/protocol_test.go | 6 | ||||
-rw-r--r-- | lib/go/thrift/serializer.go | 37 | ||||
-rw-r--r-- | lib/go/thrift/simple_server_test.go | 9 |
10 files changed, 149 insertions, 100 deletions
diff --git a/lib/go/thrift/deserializer.go b/lib/go/thrift/deserializer.go index 2f2468b29..0c68d6b5b 100644 --- a/lib/go/thrift/deserializer.go +++ b/lib/go/thrift/deserializer.go @@ -21,7 +21,6 @@ package thrift import ( "context" - "sync" ) type TDeserializer struct { @@ -81,7 +80,7 @@ func (t *TDeserializer) Read(ctx context.Context, msg TStruct, b []byte) (err er // It must be initialized with either NewTDeserializerPool or // NewTDeserializerPoolSizeFactory. type TDeserializerPool struct { - pool sync.Pool + pool *pool[TDeserializer] } // NewTDeserializerPool creates a new TDeserializerPool. @@ -89,11 +88,7 @@ type TDeserializerPool struct { // NewTDeserializer can be used as the arg here. func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool { return &TDeserializerPool{ - pool: sync.Pool{ - New: func() interface{} { - return f() - }, - }, + pool: newPool(f, nil), } } @@ -104,28 +99,26 @@ func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool { // larger than that. It just dictates the initial size. func NewTDeserializerPoolSizeFactory(size int, factory TProtocolFactory) *TDeserializerPool { return &TDeserializerPool{ - pool: sync.Pool{ - New: func() interface{} { - transport := NewTMemoryBufferLen(size) - protocol := factory.GetProtocol(transport) - - return &TDeserializer{ - Transport: transport, - Protocol: protocol, - } - }, - }, + pool: newPool(func() *TDeserializer { + transport := NewTMemoryBufferLen(size) + protocol := factory.GetProtocol(transport) + + return &TDeserializer{ + Transport: transport, + Protocol: protocol, + } + }, nil), } } func (t *TDeserializerPool) ReadString(ctx context.Context, msg TStruct, s string) error { - d := t.pool.Get().(*TDeserializer) - defer t.pool.Put(d) + d := t.pool.get() + defer t.pool.put(&d) return d.ReadString(ctx, msg, s) } func (t *TDeserializerPool) Read(ctx context.Context, msg TStruct, b []byte) error { - d := t.pool.Get().(*TDeserializer) - defer t.pool.Put(d) + d := t.pool.get() + defer t.pool.put(&d) return d.Read(ctx, msg, b) } diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go index c8bd35e32..e3c323afc 100644 --- a/lib/go/thrift/framed_transport.go +++ b/lib/go/thrift/framed_transport.go @@ -133,7 +133,7 @@ func (p *TFramedTransport) Read(buf []byte) (read int, err error) { // Make sure we return the read buffer back to pool // after we finished reading from it. if p.readBuf != nil && p.readBuf.Len() == 0 { - returnBufToPool(&p.readBuf) + bufPool.put(&p.readBuf) } }() @@ -175,7 +175,7 @@ func (p *TFramedTransport) ReadByte() (c byte, err error) { func (p *TFramedTransport) ensureWriteBufferBeforeWrite() { if p.writeBuf == nil { - p.writeBuf = getBufFromPool() + p.writeBuf = bufPool.get() } } @@ -196,7 +196,7 @@ func (p *TFramedTransport) WriteString(s string) (n int, err error) { } func (p *TFramedTransport) Flush(ctx context.Context) error { - defer returnBufToPool(&p.writeBuf) + defer bufPool.put(&p.writeBuf) size := p.writeBuf.Len() buf := p.buffer[:4] binary.BigEndian.PutUint32(buf, uint32(size)) @@ -215,9 +215,9 @@ func (p *TFramedTransport) Flush(ctx context.Context) error { func (p *TFramedTransport) readFrame() error { if p.readBuf != nil { - returnBufToPool(&p.readBuf) + bufPool.put(&p.readBuf) } - p.readBuf = getBufFromPool() + p.readBuf = bufPool.get() buf := p.buffer[:4] if _, err := io.ReadFull(p.reader, buf); err != nil { diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go index 5ec045482..3aea5a988 100644 --- a/lib/go/thrift/header_transport.go +++ b/lib/go/thrift/header_transport.go @@ -370,7 +370,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error { // Read the frame fully into frameBuffer. if t.frameBuffer == nil { - t.frameBuffer = getBufFromPool() + t.frameBuffer = bufPool.get() } _, err = io.CopyN(t.frameBuffer, t.reader, int64(frameSize)) if err != nil { @@ -407,7 +407,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error { // It closes frameReader, and also resets frame related states. func (t *THeaderTransport) endOfFrame() error { defer func() { - returnBufToPool(&t.frameBuffer) + bufPool.put(&t.frameBuffer) t.frameReader = nil }() return t.frameReader.Close() @@ -572,7 +572,7 @@ func (t *THeaderTransport) Read(p []byte) (read int, err error) { // You need to call Flush to actually write them to the transport. func (t *THeaderTransport) Write(p []byte) (int, error) { if t.writeBuffer == nil { - t.writeBuffer = getBufFromPool() + t.writeBuffer = bufPool.get() } return t.writeBuffer.Write(p) } @@ -583,7 +583,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error { return nil } - defer returnBufToPool(&t.writeBuffer) + defer bufPool.put(&t.writeBuffer) switch t.clientType { default: @@ -633,8 +633,8 @@ func (t *THeaderTransport) Flush(ctx context.Context) error { } } - payload := getBufFromPool() - defer returnBufToPool(&payload) + payload := bufPool.get() + defer bufPool.put(&payload) meta := headerMeta{ MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask, SequenceID: t.SequenceID, diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go index ce62c96a2..a0f206653 100644 --- a/lib/go/thrift/http_client.go +++ b/lib/go/thrift/http_client.go @@ -24,7 +24,6 @@ import ( "context" "errors" "io" - "io/ioutil" "net/http" "net/url" "strconv" @@ -136,7 +135,7 @@ func (p *THttpClient) closeResponse() error { // reused. Errors are being ignored here because if the connection is invalid // and this fails for some reason, the Close() method will do any remaining // cleanup. - io.Copy(ioutil.Discard, p.response.Body) + io.Copy(io.Discard, p.response.Body) err = p.response.Body.Close() } diff --git a/lib/go/thrift/http_transport.go b/lib/go/thrift/http_transport.go index bc6922762..c84aba953 100644 --- a/lib/go/thrift/http_transport.go +++ b/lib/go/thrift/http_transport.go @@ -24,7 +24,6 @@ import ( "io" "net/http" "strings" - "sync" ) // NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function @@ -41,11 +40,9 @@ func NewThriftHandlerFunc(processor TProcessor, // gz transparently compresses the HTTP response if the client supports it. func gz(handler http.HandlerFunc) http.HandlerFunc { - sp := &sync.Pool{ - New: func() interface{} { - return gzip.NewWriter(nil) - }, - } + sp := newPool(func() *gzip.Writer { + return gzip.NewWriter(nil) + }, nil) return func(w http.ResponseWriter, r *http.Request) { if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { @@ -53,11 +50,11 @@ func gz(handler http.HandlerFunc) http.HandlerFunc { return } w.Header().Set("Content-Encoding", "gzip") - gz := sp.Get().(*gzip.Writer) + gz := sp.get() gz.Reset(w) defer func() { - _ = gz.Close() - sp.Put(gz) + gz.Close() + sp.put(&gz) }() gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w} handler(gzw, r) diff --git a/lib/go/thrift/pool.go b/lib/go/thrift/pool.go new file mode 100644 index 000000000..1d623d422 --- /dev/null +++ b/lib/go/thrift/pool.go @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "sync" +) + +// pool is a generic sync.Pool wrapper with bells and whistles. +type pool[T any] struct { + pool sync.Pool + reset func(*T) +} + +// newPool creates a new pool. +// +// Both generate and reset are optional. +// Default generate is just new(T), +// When reset is nil we don't do any additional resetting when calling get. +func newPool[T any](generate func() *T, reset func(*T)) *pool[T] { + if generate == nil { + generate = func() *T { + return new(T) + } + } + return &pool[T]{ + pool: sync.Pool{ + New: func() interface{} { + return generate() + }, + }, + reset: reset, + } +} + +func (p *pool[T]) get() *T { + r := p.pool.Get().(*T) + if p.reset != nil { + p.reset(r) + } + return r +} + +func (p *pool[T]) put(r **T) { + p.pool.Put(*r) + *r = nil +} + +var bufPool = newPool(nil, func(buf *bytes.Buffer) { + buf.Reset() +}) diff --git a/lib/go/thrift/buf_pool.go b/lib/go/thrift/pool_test.go index 9708ea0e1..c717e1d6e 100644 --- a/lib/go/thrift/buf_pool.go +++ b/lib/go/thrift/pool_test.go @@ -20,33 +20,32 @@ package thrift import ( - "bytes" - "sync" + "testing" + "testing/quick" ) -var bufPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, -} - -// getBufFromPool gets a buffer out of the pool and guarantees that it's reset -// before return. -func getBufFromPool() *bytes.Buffer { - buf := bufPool.Get().(*bytes.Buffer) - buf.Reset() - return buf -} +type poolTest int -// returnBufToPool returns a buffer to the pool, and sets it to nil to avoid -// accidental usage after it's returned. -// -// You usually want to use it this way: -// -// buf := getBufFromPool() -// defer returnBufToPool(&buf) -// // use buf -func returnBufToPool(buf **bytes.Buffer) { - bufPool.Put(*buf) - *buf = nil +func TestPoolReset(t *testing.T) { + p := newPool(nil, func(elem *poolTest) { + *elem = 0 + }) + f := func(i int) (passed bool) { + pt := p.get() + defer func() { + p.put(&pt) + if pt != nil { + t.Errorf("Expected pt to be nil after put, got %#v", pt) + passed = false + } + }() + if *pt != 0 { + t.Errorf("Expected *pt to be reset to 0 after get, got %d", *pt) + } + *pt = poolTest(i) + return !t.Failed() + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } } diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go index caac78e99..d66dc65c2 100644 --- a/lib/go/thrift/protocol_test.go +++ b/lib/go/thrift/protocol_test.go @@ -22,7 +22,7 @@ package thrift import ( "bytes" "context" - "io/ioutil" + "io" "math" "net" "net/http" @@ -60,7 +60,7 @@ type HTTPEchoServer struct{} type HTTPHeaderEchoServer struct{} func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - buf, err := ioutil.ReadAll(req.Body) + buf, err := io.ReadAll(req.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write(buf) @@ -71,7 +71,7 @@ func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - buf, err := ioutil.ReadAll(req.Body) + buf, err := io.ReadAll(req.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write(buf) diff --git a/lib/go/thrift/serializer.go b/lib/go/thrift/serializer.go index f4d920186..53a674e7b 100644 --- a/lib/go/thrift/serializer.go +++ b/lib/go/thrift/serializer.go @@ -21,7 +21,6 @@ package thrift import ( "context" - "sync" ) type TSerializer struct { @@ -92,7 +91,7 @@ func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err err // It must be initialized with either NewTSerializerPool or // NewTSerializerPoolSizeFactory. type TSerializerPool struct { - pool sync.Pool + pool *pool[TSerializer] } // NewTSerializerPool creates a new TSerializerPool. @@ -100,11 +99,7 @@ type TSerializerPool struct { // NewTSerializer can be used as the arg here. func NewTSerializerPool(f func() *TSerializer) *TSerializerPool { return &TSerializerPool{ - pool: sync.Pool{ - New: func() interface{} { - return f() - }, - }, + pool: newPool(f, nil), } } @@ -115,28 +110,26 @@ func NewTSerializerPool(f func() *TSerializer) *TSerializerPool { // larger than that. It just dictates the initial size. func NewTSerializerPoolSizeFactory(size int, factory TProtocolFactory) *TSerializerPool { return &TSerializerPool{ - pool: sync.Pool{ - New: func() interface{} { - transport := NewTMemoryBufferLen(size) - protocol := factory.GetProtocol(transport) - - return &TSerializer{ - Transport: transport, - Protocol: protocol, - } - }, - }, + pool: newPool(func() *TSerializer { + transport := NewTMemoryBufferLen(size) + protocol := factory.GetProtocol(transport) + + return &TSerializer{ + Transport: transport, + Protocol: protocol, + } + }, nil), } } func (t *TSerializerPool) WriteString(ctx context.Context, msg TStruct) (string, error) { - s := t.pool.Get().(*TSerializer) - defer t.pool.Put(s) + s := t.pool.get() + defer t.pool.put(&s) return s.WriteString(ctx, msg) } func (t *TSerializerPool) Write(ctx context.Context, msg TStruct) ([]byte, error) { - s := t.pool.Get().(*TSerializer) - defer t.pool.Put(s) + s := t.pool.get() + defer t.pool.put(&s) return s.Write(ctx, msg) } diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go index e0cf151b9..f3a59ee18 100644 --- a/lib/go/thrift/simple_server_test.go +++ b/lib/go/thrift/simple_server_test.go @@ -201,11 +201,11 @@ func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) { netConn, err := net.Dial("tcp", ln.Addr().String()) if err != nil || netConn == nil { - t.Fatal("error when dial server") + t.Fatalf("error when dial server: %v", err) } time.Sleep(networkWaitDuration) - serverStopTimeout := 50 * time.Millisecond + const serverStopTimeout = 50 * time.Millisecond backupServerStopTimeout := ServerStopTimeout t.Cleanup(func() { ServerStopTimeout = backupServerStopTimeout @@ -213,13 +213,12 @@ func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) { ServerStopTimeout = serverStopTimeout st := time.Now() - err = serv.Stop() - if err != nil { + if err := serv.Stop(); err != nil { t.Errorf("error when stop server:%v", err) } if elapsed := time.Since(st); elapsed < serverStopTimeout { - t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed) + t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", serverStopTimeout, elapsed) } } |