diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2021-12-17 10:39:07 -0800 |
---|---|---|
committer | Yuxuan 'fishy' Wang <fishywang@gmail.com> | 2022-01-05 14:21:58 -0800 |
commit | 999e6e3bce217acb35b44440fd656cf169d47ed8 (patch) | |
tree | dfc4563ceda1b9cccb77a7d4ef71f4ea4c055620 | |
parent | d582a861426c43c869e71d8d6ce598a33cbab316 (diff) | |
download | thrift-999e6e3bce217acb35b44440fd656cf169d47ed8.tar.gz |
THRIFT-5490: Use pooled buffer for TFramedTransport
Client: go
Follow up on d582a8614, do the same thing on TFramedTransport.
Also update the test on the implementation of THeaderTransport to make
sure that small reads are not broken.
-rw-r--r-- | lib/go/thrift/framed_transport.go | 61 | ||||
-rw-r--r-- | lib/go/thrift/framed_transport_test.go | 73 | ||||
-rw-r--r-- | lib/go/thrift/header_transport_test.go | 39 |
3 files changed, 150 insertions, 23 deletions
diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go index 2156dd76f..c8bd35e32 100644 --- a/lib/go/thrift/framed_transport.go +++ b/lib/go/thrift/framed_transport.go @@ -36,10 +36,10 @@ type TFramedTransport struct { cfg *TConfiguration - writeBuf bytes.Buffer + writeBuf *bytes.Buffer reader *bufio.Reader - readBuf bytes.Buffer + readBuf *bytes.Buffer buffer [4]byte } @@ -129,18 +129,29 @@ func (p *TFramedTransport) Close() error { } func (p *TFramedTransport) Read(buf []byte) (read int, err error) { - read, err = p.readBuf.Read(buf) - if err != io.EOF { - return - } + defer func() { + // Make sure we return the read buffer back to pool + // after we finished reading from it. + if p.readBuf != nil && p.readBuf.Len() == 0 { + returnBufToPool(&p.readBuf) + } + }() + + if p.readBuf != nil { - // For bytes.Buffer.Read, EOF would only happen when read is zero, - // but still, do a sanity check, - // in case that behavior is changed in a future version of go stdlib. - // When that happens, just return nil error, - // and let the caller call Read again to read the next frame. - if read > 0 { - return read, nil + read, err = p.readBuf.Read(buf) + if err != io.EOF { + return + } + + // For bytes.Buffer.Read, EOF would only happen when read is zero, + // but still, do a sanity check, + // in case that behavior is changed in a future version of go stdlib. + // When that happens, just return nil error, + // and let the caller call Read again to read the next frame. + if read > 0 { + return read, nil + } } // Reaching here means that the last Read finished the last frame, @@ -162,31 +173,39 @@ func (p *TFramedTransport) ReadByte() (c byte, err error) { return } +func (p *TFramedTransport) ensureWriteBufferBeforeWrite() { + if p.writeBuf == nil { + p.writeBuf = getBufFromPool() + } +} + func (p *TFramedTransport) Write(buf []byte) (int, error) { + p.ensureWriteBufferBeforeWrite() n, err := p.writeBuf.Write(buf) return n, NewTTransportExceptionFromError(err) } func (p *TFramedTransport) WriteByte(c byte) error { + p.ensureWriteBufferBeforeWrite() return p.writeBuf.WriteByte(c) } func (p *TFramedTransport) WriteString(s string) (n int, err error) { + p.ensureWriteBufferBeforeWrite() return p.writeBuf.WriteString(s) } func (p *TFramedTransport) Flush(ctx context.Context) error { + defer returnBufToPool(&p.writeBuf) size := p.writeBuf.Len() buf := p.buffer[:4] binary.BigEndian.PutUint32(buf, uint32(size)) _, err := p.transport.Write(buf) if err != nil { - p.writeBuf.Reset() return NewTTransportExceptionFromError(err) } if size > 0 { - if _, err := io.Copy(p.transport, &p.writeBuf); err != nil { - p.writeBuf.Reset() + if _, err := io.Copy(p.transport, p.writeBuf); err != nil { return NewTTransportExceptionFromError(err) } } @@ -195,6 +214,11 @@ func (p *TFramedTransport) Flush(ctx context.Context) error { } func (p *TFramedTransport) readFrame() error { + if p.readBuf != nil { + returnBufToPool(&p.readBuf) + } + p.readBuf = getBufFromPool() + buf := p.buffer[:4] if _, err := io.ReadFull(p.reader, buf); err != nil { return err @@ -203,11 +227,14 @@ func (p *TFramedTransport) readFrame() error { if size > uint32(p.cfg.GetMaxFrameSize()) { return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size)) } - _, err := io.CopyN(&p.readBuf, p.reader, int64(size)) + _, err := io.CopyN(p.readBuf, p.reader, int64(size)) return NewTTransportExceptionFromError(err) } func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) { + if p.readBuf == nil { + return 0 + } return uint64(p.readBuf.Len()) } diff --git a/lib/go/thrift/framed_transport_test.go b/lib/go/thrift/framed_transport_test.go index 8f683ef30..4e7d9cae0 100644 --- a/lib/go/thrift/framed_transport_test.go +++ b/lib/go/thrift/framed_transport_test.go @@ -20,6 +20,9 @@ package thrift import ( + "context" + "io" + "strings" "testing" ) @@ -27,3 +30,73 @@ func TestFramedTransport(t *testing.T) { trans := NewTFramedTransport(NewTMemoryBuffer()) TransportTest(t, trans, trans) } + +func TestTFramedTransportReuseTransport(t *testing.T) { + const ( + content = "Hello, world!" + n = 10 + ) + trans := NewTMemoryBuffer() + reader := NewTFramedTransport(trans) + writer := NewTFramedTransport(trans) + + t.Run("pair", func(t *testing.T) { + for i := 0; i < n; i++ { + // write + if _, err := io.Copy(writer, strings.NewReader(content)); err != nil { + t.Fatalf("Failed to write on #%d: %v", i, err) + } + if err := writer.Flush(context.Background()); err != nil { + t.Fatalf("Failed to flush on #%d: %v", i, err) + } + + // read + read, err := io.ReadAll(oneAtATimeReader{reader}) + if err != nil { + t.Errorf("Failed to read on #%d: %v", i, err) + } + if string(read) != content { + t.Errorf("Read #%d: want %q, got %q", i, content, read) + } + } + }) + + t.Run("batched", func(t *testing.T) { + // write + for i := 0; i < n; i++ { + if _, err := io.Copy(writer, strings.NewReader(content)); err != nil { + t.Fatalf("Failed to write on #%d: %v", i, err) + } + if err := writer.Flush(context.Background()); err != nil { + t.Fatalf("Failed to flush on #%d: %v", i, err) + } + } + + // read + for i := 0; i < n; i++ { + const ( + size = len(content) + ) + var buf []byte + var err error + if i%2 == 0 { + // on even calls, use oneAtATimeReader to make + // sure that small reads are fine + buf, err = io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size))) + } else { + // on odd calls, make sure that we don't read + // more than written per frame + buf = make([]byte, size*2) + var n int + n, err = reader.Read(buf) + buf = buf[:n] + } + if err != nil { + t.Errorf("Failed to read on #%d: %v", i, err) + } + if string(buf) != content { + t.Errorf("Read #%d: want %q, got %q", i, content, buf) + } + } + }) +} diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go index 25ba8d3b1..44d0284db 100644 --- a/lib/go/thrift/header_transport_test.go +++ b/lib/go/thrift/header_transport_test.go @@ -325,7 +325,7 @@ func TestTHeaderTransportReuseTransport(t *testing.T) { } // read - read, err := io.ReadAll(reader) + read, err := io.ReadAll(oneAtATimeReader{reader}) if err != nil { t.Errorf("Failed to read on #%d: %v", i, err) } @@ -348,15 +348,42 @@ func TestTHeaderTransportReuseTransport(t *testing.T) { // read for i := 0; i < n; i++ { - buf := make([]byte, len(content)) - n, err := reader.Read(buf) + const ( + size = len(content) + ) + var buf []byte + var err error + if i%2 == 0 { + // on even calls, use oneAtATimeReader to make + // sure that small reads are fine + buf, err = io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size))) + } else { + // on odd calls, make sure that we don't read + // more than written per frame + buf = make([]byte, size*2) + var n int + n, err = reader.Read(buf) + buf = buf[:n] + } if err != nil { t.Errorf("Failed to read on #%d: %v", i, err) } - read := string(buf[:n]) - if string(read) != content { - t.Errorf("Read #%d: want %q, got %q", i, content, read) + if string(buf) != content { + t.Errorf("Read #%d: want %q, got %q", i, content, buf) } } }) } + +type oneAtATimeReader struct { + io.Reader +} + +// oneAtATimeReader forces every Read call to only read 1 byte out, +// thus forces the underlying reader's Read to be called multiple times. +func (o oneAtATimeReader) Read(buf []byte) (int, error) { + if len(buf) < 1 { + return o.Reader.Read(buf) + } + return o.Reader.Read(buf[:1]) +} |