summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2021-12-16 14:44:47 -0800
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2021-12-17 10:24:19 -0800
commitd582a861426c43c869e71d8d6ce598a33cbab316 (patch)
tree008a7d7b357761f1d8c19a3913cae16029ea3e69
parentb724787d373de99fee2222ab0eb2e052f8c8d3ed (diff)
downloadthrift-d582a861426c43c869e71d8d6ce598a33cbab316.tar.gz
THRIFT-5490: Use pooled buffer for THeaderTransport
Client: go. Instead of binding 2 buffers (read/write) to each THeaderTransport, grab one from the pool for the duration of each read/write, and return it to the pool once that read/write is done. This helps reduce the memory footprint of idle connections.
-rw-r--r--lib/go/thrift/buf_pool.go52
-rw-r--r--lib/go/thrift/header_transport.go42
-rw-r--r--lib/go/thrift/header_transport_test.go59
3 files changed, 133 insertions, 20 deletions
diff --git a/lib/go/thrift/buf_pool.go b/lib/go/thrift/buf_pool.go
new file mode 100644
index 000000000..9708ea0e1
--- /dev/null
+++ b/lib/go/thrift/buf_pool.go
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "bytes"
+ "sync"
+)
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ return new(bytes.Buffer)
+ },
+}
+
+// getBufFromPool takes a *bytes.Buffer from bufPool, which allocates a fresh
+// one when it has none to hand out, and resets it so callers start empty.
+func getBufFromPool() *bytes.Buffer {
+ buf := bufPool.Get().(*bytes.Buffer)
+ buf.Reset() // buffers go back to the pool with their old contents, so clear here
+ return buf
+}
+
+// returnBufToPool puts *buf back into the pool and then sets *buf to nil, so
+// the caller cannot keep using a buffer it no longer owns. A nil *buf is not
+// pooled: a nil entry would make a later getBufFromPool panic on Reset.
+//
+//	buf := getBufFromPool()
+//	defer returnBufToPool(&buf)
+func returnBufToPool(buf **bytes.Buffer) {
+	if *buf != nil {
+		bufPool.Put(*buf)
+	}
+	*buf = nil
+}
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index f5736df42..5ec045482 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -28,7 +28,6 @@ import (
"errors"
"fmt"
"io"
- "io/ioutil"
)
// Size in bytes for 32-bit ints.
@@ -253,14 +252,14 @@ type THeaderTransport struct {
// Reading related variables.
reader *bufio.Reader
// When frame is detected, we read the frame fully into frameBuffer.
- frameBuffer bytes.Buffer
+ frameBuffer *bytes.Buffer
// When it's non-nil, Read should read from frameReader instead of
// reader, and EOF error indicates end of frame instead of end of all
// transport.
frameReader io.ReadCloser
// Writing related variables
- writeBuffer bytes.Buffer
+ writeBuffer *bytes.Buffer
writeTransforms []THeaderTransformID
clientType clientType
@@ -370,11 +369,14 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
t.reader.Discard(size32)
// Read the frame fully into frameBuffer.
- _, err = io.CopyN(&t.frameBuffer, t.reader, int64(frameSize))
+ if t.frameBuffer == nil {
+ t.frameBuffer = getBufFromPool()
+ }
+ _, err = io.CopyN(t.frameBuffer, t.reader, int64(frameSize))
if err != nil {
return err
}
- t.frameReader = ioutil.NopCloser(&t.frameBuffer)
+ t.frameReader = io.NopCloser(t.frameBuffer)
// Peek and handle the next 32 bits.
buf = t.frameBuffer.Bytes()[:size32]
@@ -405,7 +407,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
// It closes frameReader, and also resets frame related states.
func (t *THeaderTransport) endOfFrame() error {
defer func() {
- t.frameBuffer.Reset()
+ returnBufToPool(&t.frameBuffer)
t.frameReader = nil
}()
return t.frameReader.Close()
@@ -418,7 +420,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
var err error
var meta headerMeta
- if err = binary.Read(&t.frameBuffer, binary.BigEndian, &meta); err != nil {
+ if err = binary.Read(t.frameBuffer, binary.BigEndian, &meta); err != nil {
return err
}
frameSize -= headerMetaSize
@@ -432,7 +434,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
)
}
headerBuf := NewTMemoryBuffer()
- _, err = io.CopyN(headerBuf, &t.frameBuffer, headerLength)
+ _, err = io.CopyN(headerBuf, t.frameBuffer, headerLength)
if err != nil {
return err
}
@@ -454,7 +456,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
}
if transformCount > 0 {
reader := NewTransformReaderWithCapacity(
- &t.frameBuffer,
+ t.frameBuffer,
int(transformCount),
)
t.frameReader = reader
@@ -569,16 +571,19 @@ func (t *THeaderTransport) Read(p []byte) (read int, err error) {
//
// You need to call Flush to actually write them to the transport.
func (t *THeaderTransport) Write(p []byte) (int, error) {
+ if t.writeBuffer == nil {
+ t.writeBuffer = getBufFromPool()
+ }
return t.writeBuffer.Write(p)
}
// Flush writes the appropriate header and the write buffer to the underlying transport.
func (t *THeaderTransport) Flush(ctx context.Context) error {
- if t.writeBuffer.Len() == 0 {
+ if t.writeBuffer == nil || t.writeBuffer.Len() == 0 {
return nil
}
- defer t.writeBuffer.Reset()
+ defer returnBufToPool(&t.writeBuffer)
switch t.clientType {
default:
@@ -628,24 +633,25 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
}
}
- var payload bytes.Buffer
+ payload := getBufFromPool()
+ defer returnBufToPool(&payload)
meta := headerMeta{
MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask,
SequenceID: t.SequenceID,
HeaderLength: uint16(headers.Len() / 4),
}
- if err := binary.Write(&payload, binary.BigEndian, meta); err != nil {
+ if err := binary.Write(payload, binary.BigEndian, meta); err != nil {
return NewTTransportExceptionFromError(err)
}
- if _, err := io.Copy(&payload, headers); err != nil {
+ if _, err := io.Copy(payload, headers); err != nil {
return NewTTransportExceptionFromError(err)
}
- writer, err := NewTransformWriter(&payload, t.writeTransforms)
+ writer, err := NewTransformWriter(payload, t.writeTransforms)
if err != nil {
return NewTTransportExceptionFromError(err)
}
- if _, err := io.Copy(writer, &t.writeBuffer); err != nil {
+ if _, err := io.Copy(writer, t.writeBuffer); err != nil {
return NewTTransportExceptionFromError(err)
}
if err := writer.Close(); err != nil {
@@ -659,7 +665,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
return NewTTransportExceptionFromError(err)
}
// Then write the payload
- if _, err := io.Copy(t.transport, &payload); err != nil {
+ if _, err := io.Copy(t.transport, payload); err != nil {
return NewTTransportExceptionFromError(err)
}
@@ -671,7 +677,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
}
fallthrough
case clientUnframedBinary, clientUnframedCompact:
- if _, err := io.Copy(t.transport, &t.writeBuffer); err != nil {
+ if _, err := io.Copy(t.transport, t.writeBuffer); err != nil {
return NewTTransportExceptionFromError(err)
}
}
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 65e69ee5a..25ba8d3b1 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -23,7 +23,6 @@ import (
"context"
"fmt"
"io"
- "io/ioutil"
"strings"
"testing"
"testing/quick"
@@ -87,7 +86,7 @@ func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocol
if err := reader.ReadFrame(context.Background()); err != nil {
t.Errorf("reader.ReadFrame returned error: %v", err)
}
- read, err := ioutil.ReadAll(reader)
+ read, err := io.ReadAll(reader)
if err != nil {
t.Errorf("Read returned error: %v", err)
}
@@ -305,3 +304,59 @@ func TestSetTHeaderTransportProtocolID(t *testing.T) {
t.Errorf("Expected protocol id %v, got %v", expected, actual)
}
}
+
+// TestTHeaderTransportReuseTransport reuses one reader/writer pair for many frames.
+func TestTHeaderTransportReuseTransport(t *testing.T) {
+	const (
+		content = "Hello, world!"
+		n       = 10
+	)
+	trans := NewTMemoryBuffer()
+	reader := NewTHeaderTransport(trans)
+	writer := NewTHeaderTransport(trans)
+
+	t.Run("pair", func(t *testing.T) {
+		for i := 0; i < n; i++ {
+			// write and flush a single message
+			if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+				t.Fatalf("Failed to write on #%d: %v", i, err)
+			}
+			if err := writer.Flush(context.Background()); err != nil {
+				t.Fatalf("Failed to flush on #%d: %v", i, err)
+			}
+
+			// read it back before the next message is written
+			got, err := io.ReadAll(reader)
+			if err != nil {
+				t.Errorf("Failed to read on #%d: %v", i, err)
+			}
+			if string(got) != content {
+				t.Errorf("Read #%d: want %q, got %q", i, content, got)
+			}
+		}
+	})
+
+	t.Run("batched", func(t *testing.T) {
+		// write and flush all n messages up front
+		for i := 0; i < n; i++ {
+			if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+				t.Fatalf("Failed to write on #%d: %v", i, err)
+			}
+			if err := writer.Flush(context.Background()); err != nil {
+				t.Fatalf("Failed to flush on #%d: %v", i, err)
+			}
+		}
+
+		// then read them back one message-sized chunk at a time
+		for i := 0; i < n; i++ {
+			buf := make([]byte, len(content))
+			size, err := reader.Read(buf) // named size so the loop bound n is not shadowed
+			if err != nil {
+				t.Errorf("Failed to read on #%d: %v", i, err)
+			}
+			if got := string(buf[:size]); got != content {
+				t.Errorf("Read #%d: want %q, got %q", i, content, got)
+			}
+		}
+	})
+}