summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2022-08-08 22:12:40 -0700
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2022-08-09 17:41:45 -0700
commitbdfde857a802e443a2cab1717744dee8e56cbe76 (patch)
tree71f2686e775c35667f3fdae4eb83b814ecb77ed1
parent7ae180bb1eaea8bdfd6d5714aa90b8445165ff1c (diff)
downloadthrift-bdfde857a802e443a2cab1717744dee8e56cbe76.tar.gz
Add a generic sync.Pool wrapper to go library
Since we dropped support of Go 1.18-, use generic to avoid dealing with type assertions with interface{}/any. While I'm here, also remove the usages of ioutil, as that's officially marked as deprecated in Go 1.19. Client: go
-rw-r--r--lib/go/thrift/deserializer.go37
-rw-r--r--lib/go/thrift/framed_transport.go10
-rw-r--r--lib/go/thrift/header_transport.go12
-rw-r--r--lib/go/thrift/http_client.go3
-rw-r--r--lib/go/thrift/http_transport.go15
-rw-r--r--lib/go/thrift/pool.go69
-rw-r--r--lib/go/thrift/pool_test.go (renamed from lib/go/thrift/buf_pool.go)51
-rw-r--r--lib/go/thrift/protocol_test.go6
-rw-r--r--lib/go/thrift/serializer.go37
-rw-r--r--lib/go/thrift/simple_server_test.go9
10 files changed, 149 insertions, 100 deletions
diff --git a/lib/go/thrift/deserializer.go b/lib/go/thrift/deserializer.go
index 2f2468b29..0c68d6b5b 100644
--- a/lib/go/thrift/deserializer.go
+++ b/lib/go/thrift/deserializer.go
@@ -21,7 +21,6 @@ package thrift
import (
"context"
- "sync"
)
type TDeserializer struct {
@@ -81,7 +80,7 @@ func (t *TDeserializer) Read(ctx context.Context, msg TStruct, b []byte) (err er
// It must be initialized with either NewTDeserializerPool or
// NewTDeserializerPoolSizeFactory.
type TDeserializerPool struct {
- pool sync.Pool
+ pool *pool[TDeserializer]
}
// NewTDeserializerPool creates a new TDeserializerPool.
@@ -89,11 +88,7 @@ type TDeserializerPool struct {
// NewTDeserializer can be used as the arg here.
func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
return &TDeserializerPool{
- pool: sync.Pool{
- New: func() interface{} {
- return f()
- },
- },
+ pool: newPool(f, nil),
}
}
@@ -104,28 +99,26 @@ func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
// larger than that. It just dictates the initial size.
func NewTDeserializerPoolSizeFactory(size int, factory TProtocolFactory) *TDeserializerPool {
return &TDeserializerPool{
- pool: sync.Pool{
- New: func() interface{} {
- transport := NewTMemoryBufferLen(size)
- protocol := factory.GetProtocol(transport)
-
- return &TDeserializer{
- Transport: transport,
- Protocol: protocol,
- }
- },
- },
+ pool: newPool(func() *TDeserializer {
+ transport := NewTMemoryBufferLen(size)
+ protocol := factory.GetProtocol(transport)
+
+ return &TDeserializer{
+ Transport: transport,
+ Protocol: protocol,
+ }
+ }, nil),
}
}
func (t *TDeserializerPool) ReadString(ctx context.Context, msg TStruct, s string) error {
- d := t.pool.Get().(*TDeserializer)
- defer t.pool.Put(d)
+ d := t.pool.get()
+ defer t.pool.put(&d)
return d.ReadString(ctx, msg, s)
}
func (t *TDeserializerPool) Read(ctx context.Context, msg TStruct, b []byte) error {
- d := t.pool.Get().(*TDeserializer)
- defer t.pool.Put(d)
+ d := t.pool.get()
+ defer t.pool.put(&d)
return d.Read(ctx, msg, b)
}
diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go
index c8bd35e32..e3c323afc 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -133,7 +133,7 @@ func (p *TFramedTransport) Read(buf []byte) (read int, err error) {
// Make sure we return the read buffer back to pool
// after we finished reading from it.
if p.readBuf != nil && p.readBuf.Len() == 0 {
- returnBufToPool(&p.readBuf)
+ bufPool.put(&p.readBuf)
}
}()
@@ -175,7 +175,7 @@ func (p *TFramedTransport) ReadByte() (c byte, err error) {
func (p *TFramedTransport) ensureWriteBufferBeforeWrite() {
if p.writeBuf == nil {
- p.writeBuf = getBufFromPool()
+ p.writeBuf = bufPool.get()
}
}
@@ -196,7 +196,7 @@ func (p *TFramedTransport) WriteString(s string) (n int, err error) {
}
func (p *TFramedTransport) Flush(ctx context.Context) error {
- defer returnBufToPool(&p.writeBuf)
+ defer bufPool.put(&p.writeBuf)
size := p.writeBuf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
@@ -215,9 +215,9 @@ func (p *TFramedTransport) Flush(ctx context.Context) error {
func (p *TFramedTransport) readFrame() error {
if p.readBuf != nil {
- returnBufToPool(&p.readBuf)
+ bufPool.put(&p.readBuf)
}
- p.readBuf = getBufFromPool()
+ p.readBuf = bufPool.get()
buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index 5ec045482..3aea5a988 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -370,7 +370,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
// Read the frame fully into frameBuffer.
if t.frameBuffer == nil {
- t.frameBuffer = getBufFromPool()
+ t.frameBuffer = bufPool.get()
}
_, err = io.CopyN(t.frameBuffer, t.reader, int64(frameSize))
if err != nil {
@@ -407,7 +407,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
// It closes frameReader, and also resets frame related states.
func (t *THeaderTransport) endOfFrame() error {
defer func() {
- returnBufToPool(&t.frameBuffer)
+ bufPool.put(&t.frameBuffer)
t.frameReader = nil
}()
return t.frameReader.Close()
@@ -572,7 +572,7 @@ func (t *THeaderTransport) Read(p []byte) (read int, err error) {
// You need to call Flush to actually write them to the transport.
func (t *THeaderTransport) Write(p []byte) (int, error) {
if t.writeBuffer == nil {
- t.writeBuffer = getBufFromPool()
+ t.writeBuffer = bufPool.get()
}
return t.writeBuffer.Write(p)
}
@@ -583,7 +583,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
return nil
}
- defer returnBufToPool(&t.writeBuffer)
+ defer bufPool.put(&t.writeBuffer)
switch t.clientType {
default:
@@ -633,8 +633,8 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
}
}
- payload := getBufFromPool()
- defer returnBufToPool(&payload)
+ payload := bufPool.get()
+ defer bufPool.put(&payload)
meta := headerMeta{
MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask,
SequenceID: t.SequenceID,
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go
index ce62c96a2..a0f206653 100644
--- a/lib/go/thrift/http_client.go
+++ b/lib/go/thrift/http_client.go
@@ -24,7 +24,6 @@ import (
"context"
"errors"
"io"
- "io/ioutil"
"net/http"
"net/url"
"strconv"
@@ -136,7 +135,7 @@ func (p *THttpClient) closeResponse() error {
// reused. Errors are being ignored here because if the connection is invalid
// and this fails for some reason, the Close() method will do any remaining
// cleanup.
- io.Copy(ioutil.Discard, p.response.Body)
+ io.Copy(io.Discard, p.response.Body)
err = p.response.Body.Close()
}
diff --git a/lib/go/thrift/http_transport.go b/lib/go/thrift/http_transport.go
index bc6922762..c84aba953 100644
--- a/lib/go/thrift/http_transport.go
+++ b/lib/go/thrift/http_transport.go
@@ -24,7 +24,6 @@ import (
"io"
"net/http"
"strings"
- "sync"
)
// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
@@ -41,11 +40,9 @@ func NewThriftHandlerFunc(processor TProcessor,
// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
- sp := &sync.Pool{
- New: func() interface{} {
- return gzip.NewWriter(nil)
- },
- }
+ sp := newPool(func() *gzip.Writer {
+ return gzip.NewWriter(nil)
+ }, nil)
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
@@ -53,11 +50,11 @@ func gz(handler http.HandlerFunc) http.HandlerFunc {
return
}
w.Header().Set("Content-Encoding", "gzip")
- gz := sp.Get().(*gzip.Writer)
+ gz := sp.get()
gz.Reset(w)
defer func() {
- _ = gz.Close()
- sp.Put(gz)
+ gz.Close()
+ sp.put(&gz)
}()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
diff --git a/lib/go/thrift/pool.go b/lib/go/thrift/pool.go
new file mode 100644
index 000000000..1d623d422
--- /dev/null
+++ b/lib/go/thrift/pool.go
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "bytes"
+ "sync"
+)
+
+// pool is a generic sync.Pool wrapper with bells and whistles.
+type pool[T any] struct {
+ pool sync.Pool
+ reset func(*T)
+}
+
+// newPool creates a new pool.
+//
+// Both generate and reset are optional.
+// Default generate is just new(T),
+// When reset is nil we don't do any additional resetting when calling get.
+func newPool[T any](generate func() *T, reset func(*T)) *pool[T] {
+ if generate == nil {
+ generate = func() *T {
+ return new(T)
+ }
+ }
+ return &pool[T]{
+ pool: sync.Pool{
+ New: func() interface{} {
+ return generate()
+ },
+ },
+ reset: reset,
+ }
+}
+
+func (p *pool[T]) get() *T {
+ r := p.pool.Get().(*T)
+ if p.reset != nil {
+ p.reset(r)
+ }
+ return r
+}
+
+func (p *pool[T]) put(r **T) {
+ p.pool.Put(*r)
+ *r = nil
+}
+
+var bufPool = newPool(nil, func(buf *bytes.Buffer) {
+ buf.Reset()
+})
diff --git a/lib/go/thrift/buf_pool.go b/lib/go/thrift/pool_test.go
index 9708ea0e1..c717e1d6e 100644
--- a/lib/go/thrift/buf_pool.go
+++ b/lib/go/thrift/pool_test.go
@@ -20,33 +20,32 @@
package thrift
import (
- "bytes"
- "sync"
+ "testing"
+ "testing/quick"
)
-var bufPool = sync.Pool{
- New: func() interface{} {
- return new(bytes.Buffer)
- },
-}
-
-// getBufFromPool gets a buffer out of the pool and guarantees that it's reset
-// before return.
-func getBufFromPool() *bytes.Buffer {
- buf := bufPool.Get().(*bytes.Buffer)
- buf.Reset()
- return buf
-}
+type poolTest int
-// returnBufToPool returns a buffer to the pool, and sets it to nil to avoid
-// accidental usage after it's returned.
-//
-// You usually want to use it this way:
-//
-// buf := getBufFromPool()
-// defer returnBufToPool(&buf)
-// // use buf
-func returnBufToPool(buf **bytes.Buffer) {
- bufPool.Put(*buf)
- *buf = nil
+func TestPoolReset(t *testing.T) {
+ p := newPool(nil, func(elem *poolTest) {
+ *elem = 0
+ })
+ f := func(i int) (passed bool) {
+ pt := p.get()
+ defer func() {
+ p.put(&pt)
+ if pt != nil {
+ t.Errorf("Expected pt to be nil after put, got %#v", pt)
+ passed = false
+ }
+ }()
+ if *pt != 0 {
+ t.Errorf("Expected *pt to be reset to 0 after get, got %d", *pt)
+ }
+ *pt = poolTest(i)
+ return !t.Failed()
+ }
+ if err := quick.Check(f, nil); err != nil {
+ t.Error(err)
+ }
}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index caac78e99..d66dc65c2 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -22,7 +22,7 @@ package thrift
import (
"bytes"
"context"
- "io/ioutil"
+ "io"
"math"
"net"
"net/http"
@@ -60,7 +60,7 @@ type HTTPEchoServer struct{}
type HTTPHeaderEchoServer struct{}
func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- buf, err := ioutil.ReadAll(req.Body)
+ buf, err := io.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
@@ -71,7 +71,7 @@ func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- buf, err := ioutil.ReadAll(req.Body)
+ buf, err := io.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
diff --git a/lib/go/thrift/serializer.go b/lib/go/thrift/serializer.go
index f4d920186..53a674e7b 100644
--- a/lib/go/thrift/serializer.go
+++ b/lib/go/thrift/serializer.go
@@ -21,7 +21,6 @@ package thrift
import (
"context"
- "sync"
)
type TSerializer struct {
@@ -92,7 +91,7 @@ func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err err
// It must be initialized with either NewTSerializerPool or
// NewTSerializerPoolSizeFactory.
type TSerializerPool struct {
- pool sync.Pool
+ pool *pool[TSerializer]
}
// NewTSerializerPool creates a new TSerializerPool.
@@ -100,11 +99,7 @@ type TSerializerPool struct {
// NewTSerializer can be used as the arg here.
func NewTSerializerPool(f func() *TSerializer) *TSerializerPool {
return &TSerializerPool{
- pool: sync.Pool{
- New: func() interface{} {
- return f()
- },
- },
+ pool: newPool(f, nil),
}
}
@@ -115,28 +110,26 @@ func NewTSerializerPool(f func() *TSerializer) *TSerializerPool {
// larger than that. It just dictates the initial size.
func NewTSerializerPoolSizeFactory(size int, factory TProtocolFactory) *TSerializerPool {
return &TSerializerPool{
- pool: sync.Pool{
- New: func() interface{} {
- transport := NewTMemoryBufferLen(size)
- protocol := factory.GetProtocol(transport)
-
- return &TSerializer{
- Transport: transport,
- Protocol: protocol,
- }
- },
- },
+ pool: newPool(func() *TSerializer {
+ transport := NewTMemoryBufferLen(size)
+ protocol := factory.GetProtocol(transport)
+
+ return &TSerializer{
+ Transport: transport,
+ Protocol: protocol,
+ }
+ }, nil),
}
}
func (t *TSerializerPool) WriteString(ctx context.Context, msg TStruct) (string, error) {
- s := t.pool.Get().(*TSerializer)
- defer t.pool.Put(s)
+ s := t.pool.get()
+ defer t.pool.put(&s)
return s.WriteString(ctx, msg)
}
func (t *TSerializerPool) Write(ctx context.Context, msg TStruct) ([]byte, error) {
- s := t.pool.Get().(*TSerializer)
- defer t.pool.Put(s)
+ s := t.pool.get()
+ defer t.pool.put(&s)
return s.Write(ctx, msg)
}
diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go
index e0cf151b9..f3a59ee18 100644
--- a/lib/go/thrift/simple_server_test.go
+++ b/lib/go/thrift/simple_server_test.go
@@ -201,11 +201,11 @@ func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) {
netConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil || netConn == nil {
- t.Fatal("error when dial server")
+ t.Fatalf("error when dial server: %v", err)
}
time.Sleep(networkWaitDuration)
- serverStopTimeout := 50 * time.Millisecond
+ const serverStopTimeout = 50 * time.Millisecond
backupServerStopTimeout := ServerStopTimeout
t.Cleanup(func() {
ServerStopTimeout = backupServerStopTimeout
@@ -213,13 +213,12 @@ func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) {
ServerStopTimeout = serverStopTimeout
st := time.Now()
- err = serv.Stop()
- if err != nil {
+ if err := serv.Stop(); err != nil {
t.Errorf("error when stop server:%v", err)
}
if elapsed := time.Since(st); elapsed < serverStopTimeout {
- t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed)
+ t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", serverStopTimeout, elapsed)
}
}