summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2021-12-17 10:39:07 -0800
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2022-01-05 14:21:58 -0800
commit999e6e3bce217acb35b44440fd656cf169d47ed8 (patch)
treedfc4563ceda1b9cccb77a7d4ef71f4ea4c055620
parentd582a861426c43c869e71d8d6ce598a33cbab316 (diff)
downloadthrift-999e6e3bce217acb35b44440fd656cf169d47ed8.tar.gz
THRIFT-5490: Use pooled buffer for TFramedTransport
Client: go Follow up on d582a8614, do the same thing on TFramedTransport. Also update the test on the implementation of THeaderTransport to make sure that small reads are not broken.
-rw-r--r--lib/go/thrift/framed_transport.go61
-rw-r--r--lib/go/thrift/framed_transport_test.go73
-rw-r--r--lib/go/thrift/header_transport_test.go39
3 files changed, 150 insertions, 23 deletions
diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go
index 2156dd76f..c8bd35e32 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -36,10 +36,10 @@ type TFramedTransport struct {
cfg *TConfiguration
- writeBuf bytes.Buffer
+ writeBuf *bytes.Buffer
reader *bufio.Reader
- readBuf bytes.Buffer
+ readBuf *bytes.Buffer
buffer [4]byte
}
@@ -129,18 +129,29 @@ func (p *TFramedTransport) Close() error {
}
func (p *TFramedTransport) Read(buf []byte) (read int, err error) {
- read, err = p.readBuf.Read(buf)
- if err != io.EOF {
- return
- }
+ defer func() {
+ // Make sure we return the read buffer back to pool
+ // after we finished reading from it.
+ if p.readBuf != nil && p.readBuf.Len() == 0 {
+ returnBufToPool(&p.readBuf)
+ }
+ }()
+
+ if p.readBuf != nil {
- // For bytes.Buffer.Read, EOF would only happen when read is zero,
- // but still, do a sanity check,
- // in case that behavior is changed in a future version of go stdlib.
- // When that happens, just return nil error,
- // and let the caller call Read again to read the next frame.
- if read > 0 {
- return read, nil
+ read, err = p.readBuf.Read(buf)
+ if err != io.EOF {
+ return
+ }
+
+ // For bytes.Buffer.Read, EOF would only happen when read is zero,
+ // but still, do a sanity check,
+ // in case that behavior is changed in a future version of go stdlib.
+ // When that happens, just return nil error,
+ // and let the caller call Read again to read the next frame.
+ if read > 0 {
+ return read, nil
+ }
}
// Reaching here means that the last Read finished the last frame,
@@ -162,31 +173,39 @@ func (p *TFramedTransport) ReadByte() (c byte, err error) {
return
}
+func (p *TFramedTransport) ensureWriteBufferBeforeWrite() {
+ if p.writeBuf == nil {
+ p.writeBuf = getBufFromPool()
+ }
+}
+
func (p *TFramedTransport) Write(buf []byte) (int, error) {
+ p.ensureWriteBufferBeforeWrite()
n, err := p.writeBuf.Write(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) WriteByte(c byte) error {
+ p.ensureWriteBufferBeforeWrite()
return p.writeBuf.WriteByte(c)
}
func (p *TFramedTransport) WriteString(s string) (n int, err error) {
+ p.ensureWriteBufferBeforeWrite()
return p.writeBuf.WriteString(s)
}
func (p *TFramedTransport) Flush(ctx context.Context) error {
+ defer returnBufToPool(&p.writeBuf)
size := p.writeBuf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
_, err := p.transport.Write(buf)
if err != nil {
- p.writeBuf.Reset()
return NewTTransportExceptionFromError(err)
}
if size > 0 {
- if _, err := io.Copy(p.transport, &p.writeBuf); err != nil {
- p.writeBuf.Reset()
+ if _, err := io.Copy(p.transport, p.writeBuf); err != nil {
return NewTTransportExceptionFromError(err)
}
}
@@ -195,6 +214,11 @@ func (p *TFramedTransport) Flush(ctx context.Context) error {
}
func (p *TFramedTransport) readFrame() error {
+ if p.readBuf != nil {
+ returnBufToPool(&p.readBuf)
+ }
+ p.readBuf = getBufFromPool()
+
buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
return err
@@ -203,11 +227,14 @@ func (p *TFramedTransport) readFrame() error {
if size > uint32(p.cfg.GetMaxFrameSize()) {
return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
}
- _, err := io.CopyN(&p.readBuf, p.reader, int64(size))
+ _, err := io.CopyN(p.readBuf, p.reader, int64(size))
return NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
+ if p.readBuf == nil {
+ return 0
+ }
return uint64(p.readBuf.Len())
}
diff --git a/lib/go/thrift/framed_transport_test.go b/lib/go/thrift/framed_transport_test.go
index 8f683ef30..4e7d9cae0 100644
--- a/lib/go/thrift/framed_transport_test.go
+++ b/lib/go/thrift/framed_transport_test.go
@@ -20,6 +20,9 @@
package thrift
import (
+ "context"
+ "io"
+ "strings"
"testing"
)
@@ -27,3 +30,73 @@ func TestFramedTransport(t *testing.T) {
trans := NewTFramedTransport(NewTMemoryBuffer())
TransportTest(t, trans, trans)
}
+
+func TestTFramedTransportReuseTransport(t *testing.T) {
+ const (
+ content = "Hello, world!"
+ n = 10
+ )
+ trans := NewTMemoryBuffer()
+ reader := NewTFramedTransport(trans)
+ writer := NewTFramedTransport(trans)
+
+ t.Run("pair", func(t *testing.T) {
+ for i := 0; i < n; i++ {
+ // write
+ if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+ t.Fatalf("Failed to write on #%d: %v", i, err)
+ }
+ if err := writer.Flush(context.Background()); err != nil {
+ t.Fatalf("Failed to flush on #%d: %v", i, err)
+ }
+
+ // read
+ read, err := io.ReadAll(oneAtATimeReader{reader})
+ if err != nil {
+ t.Errorf("Failed to read on #%d: %v", i, err)
+ }
+ if string(read) != content {
+ t.Errorf("Read #%d: want %q, got %q", i, content, read)
+ }
+ }
+ })
+
+ t.Run("batched", func(t *testing.T) {
+ // write
+ for i := 0; i < n; i++ {
+ if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+ t.Fatalf("Failed to write on #%d: %v", i, err)
+ }
+ if err := writer.Flush(context.Background()); err != nil {
+ t.Fatalf("Failed to flush on #%d: %v", i, err)
+ }
+ }
+
+ // read
+ for i := 0; i < n; i++ {
+ const (
+ size = len(content)
+ )
+ var buf []byte
+ var err error
+ if i%2 == 0 {
+ // on even calls, use oneAtATimeReader to make
+ // sure that small reads are fine
+ buf, err = io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size)))
+ } else {
+ // on odd calls, make sure that we don't read
+ // more than written per frame
+ buf = make([]byte, size*2)
+ var n int
+ n, err = reader.Read(buf)
+ buf = buf[:n]
+ }
+ if err != nil {
+ t.Errorf("Failed to read on #%d: %v", i, err)
+ }
+ if string(buf) != content {
+ t.Errorf("Read #%d: want %q, got %q", i, content, buf)
+ }
+ }
+ })
+}
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 25ba8d3b1..44d0284db 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -325,7 +325,7 @@ func TestTHeaderTransportReuseTransport(t *testing.T) {
}
// read
- read, err := io.ReadAll(reader)
+ read, err := io.ReadAll(oneAtATimeReader{reader})
if err != nil {
t.Errorf("Failed to read on #%d: %v", i, err)
}
@@ -348,15 +348,42 @@ func TestTHeaderTransportReuseTransport(t *testing.T) {
// read
for i := 0; i < n; i++ {
- buf := make([]byte, len(content))
- n, err := reader.Read(buf)
+ const (
+ size = len(content)
+ )
+ var buf []byte
+ var err error
+ if i%2 == 0 {
+ // on even calls, use oneAtATimeReader to make
+ // sure that small reads are fine
+ buf, err = io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size)))
+ } else {
+ // on odd calls, make sure that we don't read
+ // more than written per frame
+ buf = make([]byte, size*2)
+ var n int
+ n, err = reader.Read(buf)
+ buf = buf[:n]
+ }
if err != nil {
t.Errorf("Failed to read on #%d: %v", i, err)
}
- read := string(buf[:n])
- if string(read) != content {
- t.Errorf("Read #%d: want %q, got %q", i, content, read)
+ if string(buf) != content {
+ t.Errorf("Read #%d: want %q, got %q", i, content, buf)
}
}
})
}
+
+type oneAtATimeReader struct {
+ io.Reader
+}
+
+// oneAtATimeReader forces every Read call to only read 1 byte out,
+// thus forces the underlying reader's Read to be called multiple times.
+func (o oneAtATimeReader) Read(buf []byte) (int, error) {
+ if len(buf) < 1 {
+ return o.Reader.Read(buf)
+ }
+ return o.Reader.Read(buf[:1])
+}