summaryrefslogtreecommitdiff
path: root/lib/go/thrift
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2021-12-17 10:39:07 -0800
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2022-01-05 14:21:58 -0800
commit999e6e3bce217acb35b44440fd656cf169d47ed8 (patch)
treedfc4563ceda1b9cccb77a7d4ef71f4ea4c055620 /lib/go/thrift
parentd582a861426c43c869e71d8d6ce598a33cbab316 (diff)
downloadthrift-999e6e3bce217acb35b44440fd656cf169d47ed8.tar.gz
THRIFT-5490: Use pooled buffer for TFramedTransport
Client: go Follow up on d582a8614, do the same thing on TFramedTransport. Also update the test on the implementation of THeaderTransport to make sure that small reads are not broken.
Diffstat (limited to 'lib/go/thrift')
-rw-r--r--lib/go/thrift/framed_transport.go61
-rw-r--r--lib/go/thrift/framed_transport_test.go73
-rw-r--r--lib/go/thrift/header_transport_test.go39
3 files changed, 150 insertions, 23 deletions
diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go
index 2156dd76f..c8bd35e32 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -36,10 +36,10 @@ type TFramedTransport struct {
cfg *TConfiguration
- writeBuf bytes.Buffer
+ writeBuf *bytes.Buffer
reader *bufio.Reader
- readBuf bytes.Buffer
+ readBuf *bytes.Buffer
buffer [4]byte
}
@@ -129,18 +129,29 @@ func (p *TFramedTransport) Close() error {
}
func (p *TFramedTransport) Read(buf []byte) (read int, err error) {
- read, err = p.readBuf.Read(buf)
- if err != io.EOF {
- return
- }
+ defer func() {
+ // Make sure we return the read buffer back to pool
+ // after we finished reading from it.
+ if p.readBuf != nil && p.readBuf.Len() == 0 {
+ returnBufToPool(&p.readBuf)
+ }
+ }()
+
+ if p.readBuf != nil {
- // For bytes.Buffer.Read, EOF would only happen when read is zero,
- // but still, do a sanity check,
- // in case that behavior is changed in a future version of go stdlib.
- // When that happens, just return nil error,
- // and let the caller call Read again to read the next frame.
- if read > 0 {
- return read, nil
+ read, err = p.readBuf.Read(buf)
+ if err != io.EOF {
+ return
+ }
+
+ // For bytes.Buffer.Read, EOF would only happen when read is zero,
+ // but still, do a sanity check,
+ // in case that behavior is changed in a future version of go stdlib.
+ // When that happens, just return nil error,
+ // and let the caller call Read again to read the next frame.
+ if read > 0 {
+ return read, nil
+ }
}
// Reaching here means that the last Read finished the last frame,
@@ -162,31 +173,39 @@ func (p *TFramedTransport) ReadByte() (c byte, err error) {
return
}
+func (p *TFramedTransport) ensureWriteBufferBeforeWrite() {
+ if p.writeBuf == nil {
+ p.writeBuf = getBufFromPool()
+ }
+}
+
func (p *TFramedTransport) Write(buf []byte) (int, error) {
+ p.ensureWriteBufferBeforeWrite()
n, err := p.writeBuf.Write(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) WriteByte(c byte) error {
+ p.ensureWriteBufferBeforeWrite()
return p.writeBuf.WriteByte(c)
}
func (p *TFramedTransport) WriteString(s string) (n int, err error) {
+ p.ensureWriteBufferBeforeWrite()
return p.writeBuf.WriteString(s)
}
func (p *TFramedTransport) Flush(ctx context.Context) error {
+ defer returnBufToPool(&p.writeBuf)
size := p.writeBuf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
_, err := p.transport.Write(buf)
if err != nil {
- p.writeBuf.Reset()
return NewTTransportExceptionFromError(err)
}
if size > 0 {
- if _, err := io.Copy(p.transport, &p.writeBuf); err != nil {
- p.writeBuf.Reset()
+ if _, err := io.Copy(p.transport, p.writeBuf); err != nil {
return NewTTransportExceptionFromError(err)
}
}
@@ -195,6 +214,11 @@ func (p *TFramedTransport) Flush(ctx context.Context) error {
}
func (p *TFramedTransport) readFrame() error {
+ if p.readBuf != nil {
+ returnBufToPool(&p.readBuf)
+ }
+ p.readBuf = getBufFromPool()
+
buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
return err
@@ -203,11 +227,14 @@ func (p *TFramedTransport) readFrame() error {
if size > uint32(p.cfg.GetMaxFrameSize()) {
return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
}
- _, err := io.CopyN(&p.readBuf, p.reader, int64(size))
+ _, err := io.CopyN(p.readBuf, p.reader, int64(size))
return NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
+ if p.readBuf == nil {
+ return 0
+ }
return uint64(p.readBuf.Len())
}
diff --git a/lib/go/thrift/framed_transport_test.go b/lib/go/thrift/framed_transport_test.go
index 8f683ef30..4e7d9cae0 100644
--- a/lib/go/thrift/framed_transport_test.go
+++ b/lib/go/thrift/framed_transport_test.go
@@ -20,6 +20,9 @@
package thrift
import (
+ "context"
+ "io"
+ "strings"
"testing"
)
@@ -27,3 +30,73 @@ func TestFramedTransport(t *testing.T) {
trans := NewTFramedTransport(NewTMemoryBuffer())
TransportTest(t, trans, trans)
}
+
+func TestTFramedTransportReuseTransport(t *testing.T) {
+ const (
+ content = "Hello, world!"
+ n = 10
+ )
+ trans := NewTMemoryBuffer()
+ reader := NewTFramedTransport(trans)
+ writer := NewTFramedTransport(trans)
+
+ t.Run("pair", func(t *testing.T) {
+ for i := 0; i < n; i++ {
+ // write
+ if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+ t.Fatalf("Failed to write on #%d: %v", i, err)
+ }
+ if err := writer.Flush(context.Background()); err != nil {
+ t.Fatalf("Failed to flush on #%d: %v", i, err)
+ }
+
+ // read
+ read, err := io.ReadAll(oneAtATimeReader{reader})
+ if err != nil {
+ t.Errorf("Failed to read on #%d: %v", i, err)
+ }
+ if string(read) != content {
+ t.Errorf("Read #%d: want %q, got %q", i, content, read)
+ }
+ }
+ })
+
+ t.Run("batched", func(t *testing.T) {
+ // write
+ for i := 0; i < n; i++ {
+ if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+ t.Fatalf("Failed to write on #%d: %v", i, err)
+ }
+ if err := writer.Flush(context.Background()); err != nil {
+ t.Fatalf("Failed to flush on #%d: %v", i, err)
+ }
+ }
+
+ // read
+ for i := 0; i < n; i++ {
+ const (
+ size = len(content)
+ )
+ var buf []byte
+ var err error
+ if i%2 == 0 {
+ // on even calls, use oneAtATimeReader to make
+ // sure that small reads are fine
+ buf, err = io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size)))
+ } else {
+ // on odd calls, make sure that we don't read
+ // more than written per frame
+ buf = make([]byte, size*2)
+ var n int
+ n, err = reader.Read(buf)
+ buf = buf[:n]
+ }
+ if err != nil {
+ t.Errorf("Failed to read on #%d: %v", i, err)
+ }
+ if string(buf) != content {
+ t.Errorf("Read #%d: want %q, got %q", i, content, buf)
+ }
+ }
+ })
+}
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 25ba8d3b1..44d0284db 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -325,7 +325,7 @@ func TestTHeaderTransportReuseTransport(t *testing.T) {
}
// read
- read, err := io.ReadAll(reader)
+ read, err := io.ReadAll(oneAtATimeReader{reader})
if err != nil {
t.Errorf("Failed to read on #%d: %v", i, err)
}
@@ -348,15 +348,42 @@ func TestTHeaderTransportReuseTransport(t *testing.T) {
// read
for i := 0; i < n; i++ {
- buf := make([]byte, len(content))
- n, err := reader.Read(buf)
+ const (
+ size = len(content)
+ )
+ var buf []byte
+ var err error
+ if i%2 == 0 {
+ // on even calls, use oneAtATimeReader to make
+ // sure that small reads are fine
+ buf, err = io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size)))
+ } else {
+ // on odd calls, make sure that we don't read
+ // more than written per frame
+ buf = make([]byte, size*2)
+ var n int
+ n, err = reader.Read(buf)
+ buf = buf[:n]
+ }
if err != nil {
t.Errorf("Failed to read on #%d: %v", i, err)
}
- read := string(buf[:n])
- if string(read) != content {
- t.Errorf("Read #%d: want %q, got %q", i, content, read)
+ if string(buf) != content {
+ t.Errorf("Read #%d: want %q, got %q", i, content, buf)
}
}
})
}
+
+type oneAtATimeReader struct {
+ io.Reader
+}
+
+// oneAtATimeReader forces every Read call to only read 1 byte out,
+// thus forces the underlying reader's Read to be called multiple times.
+func (o oneAtATimeReader) Read(buf []byte) (int, error) {
+ if len(buf) < 1 {
+ return o.Reader.Read(buf)
+ }
+ return o.Reader.Read(buf[:1])
+}