diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2021-12-16 14:44:47 -0800 |
---|---|---|
committer | Yuxuan 'fishy' Wang <fishywang@gmail.com> | 2021-12-17 10:24:19 -0800 |
commit | d582a861426c43c869e71d8d6ce598a33cbab316 (patch) | |
tree | 008a7d7b357761f1d8c19a3913cae16029ea3e69 | |
parent | b724787d373de99fee2222ab0eb2e052f8c8d3ed (diff) | |
download | thrift-d582a861426c43c869e71d8d6ce598a33cbab316.tar.gz |
THRIFT-5490: Use pooled buffer for THeaderTransport
Client: go
Instead of binding 2 buffers (read/write) to each THeaderTransport, grab
one from the pool for the duration of each read/write, and return it
to the pool once that read/write is done. This helps reduce
the memory footprint of idle connections.
-rw-r--r-- | lib/go/thrift/buf_pool.go | 52 | ||||
-rw-r--r-- | lib/go/thrift/header_transport.go | 42 | ||||
-rw-r--r-- | lib/go/thrift/header_transport_test.go | 59 |
3 files changed, 133 insertions, 20 deletions
diff --git a/lib/go/thrift/buf_pool.go b/lib/go/thrift/buf_pool.go new file mode 100644 index 000000000..9708ea0e1 --- /dev/null +++ b/lib/go/thrift/buf_pool.go @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "sync" +) + +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +// getBufFromPool gets a buffer out of the pool and guarantees that it's reset +// before return. +func getBufFromPool() *bytes.Buffer { + buf := bufPool.Get().(*bytes.Buffer) + buf.Reset() + return buf +} + +// returnBufToPool returns a buffer to the pool, and sets it to nil to avoid +// accidental usage after it's returned. +// +// You usually want to use it this way: +// +// buf := getBufFromPool() +// defer returnBufToPool(&buf) +// // use buf +func returnBufToPool(buf **bytes.Buffer) { + bufPool.Put(*buf) + *buf = nil +} diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go index f5736df42..5ec045482 100644 --- a/lib/go/thrift/header_transport.go +++ b/lib/go/thrift/header_transport.go @@ -28,7 +28,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" ) // Size in bytes for 32-bit ints. 
@@ -253,14 +252,14 @@ type THeaderTransport struct { // Reading related variables. reader *bufio.Reader // When frame is detected, we read the frame fully into frameBuffer. - frameBuffer bytes.Buffer + frameBuffer *bytes.Buffer // When it's non-nil, Read should read from frameReader instead of // reader, and EOF error indicates end of frame instead of end of all // transport. frameReader io.ReadCloser // Writing related variables - writeBuffer bytes.Buffer + writeBuffer *bytes.Buffer writeTransforms []THeaderTransformID clientType clientType @@ -370,11 +369,14 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error { t.reader.Discard(size32) // Read the frame fully into frameBuffer. - _, err = io.CopyN(&t.frameBuffer, t.reader, int64(frameSize)) + if t.frameBuffer == nil { + t.frameBuffer = getBufFromPool() + } + _, err = io.CopyN(t.frameBuffer, t.reader, int64(frameSize)) if err != nil { return err } - t.frameReader = ioutil.NopCloser(&t.frameBuffer) + t.frameReader = io.NopCloser(t.frameBuffer) // Peek and handle the next 32 bits. buf = t.frameBuffer.Bytes()[:size32] @@ -405,7 +407,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error { // It closes frameReader, and also resets frame related states. 
func (t *THeaderTransport) endOfFrame() error { defer func() { - t.frameBuffer.Reset() + returnBufToPool(&t.frameBuffer) t.frameReader = nil }() return t.frameReader.Close() @@ -418,7 +420,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e var err error var meta headerMeta - if err = binary.Read(&t.frameBuffer, binary.BigEndian, &meta); err != nil { + if err = binary.Read(t.frameBuffer, binary.BigEndian, &meta); err != nil { return err } frameSize -= headerMetaSize @@ -432,7 +434,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e ) } headerBuf := NewTMemoryBuffer() - _, err = io.CopyN(headerBuf, &t.frameBuffer, headerLength) + _, err = io.CopyN(headerBuf, t.frameBuffer, headerLength) if err != nil { return err } @@ -454,7 +456,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e } if transformCount > 0 { reader := NewTransformReaderWithCapacity( - &t.frameBuffer, + t.frameBuffer, int(transformCount), ) t.frameReader = reader @@ -569,16 +571,19 @@ func (t *THeaderTransport) Read(p []byte) (read int, err error) { // // You need to call Flush to actually write them to the transport. func (t *THeaderTransport) Write(p []byte) (int, error) { + if t.writeBuffer == nil { + t.writeBuffer = getBufFromPool() + } return t.writeBuffer.Write(p) } // Flush writes the appropriate header and the write buffer to the underlying transport. 
func (t *THeaderTransport) Flush(ctx context.Context) error { - if t.writeBuffer.Len() == 0 { + if t.writeBuffer == nil || t.writeBuffer.Len() == 0 { return nil } - defer t.writeBuffer.Reset() + defer returnBufToPool(&t.writeBuffer) switch t.clientType { default: @@ -628,24 +633,25 @@ func (t *THeaderTransport) Flush(ctx context.Context) error { } } - var payload bytes.Buffer + payload := getBufFromPool() + defer returnBufToPool(&payload) meta := headerMeta{ MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask, SequenceID: t.SequenceID, HeaderLength: uint16(headers.Len() / 4), } - if err := binary.Write(&payload, binary.BigEndian, meta); err != nil { + if err := binary.Write(payload, binary.BigEndian, meta); err != nil { return NewTTransportExceptionFromError(err) } - if _, err := io.Copy(&payload, headers); err != nil { + if _, err := io.Copy(payload, headers); err != nil { return NewTTransportExceptionFromError(err) } - writer, err := NewTransformWriter(&payload, t.writeTransforms) + writer, err := NewTransformWriter(payload, t.writeTransforms) if err != nil { return NewTTransportExceptionFromError(err) } - if _, err := io.Copy(writer, &t.writeBuffer); err != nil { + if _, err := io.Copy(writer, t.writeBuffer); err != nil { return NewTTransportExceptionFromError(err) } if err := writer.Close(); err != nil { @@ -659,7 +665,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error { return NewTTransportExceptionFromError(err) } // Then write the payload - if _, err := io.Copy(t.transport, &payload); err != nil { + if _, err := io.Copy(t.transport, payload); err != nil { return NewTTransportExceptionFromError(err) } @@ -671,7 +677,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error { } fallthrough case clientUnframedBinary, clientUnframedCompact: - if _, err := io.Copy(t.transport, &t.writeBuffer); err != nil { + if _, err := io.Copy(t.transport, t.writeBuffer); err != nil { return NewTTransportExceptionFromError(err) } } diff --git 
a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go index 65e69ee5a..25ba8d3b1 100644 --- a/lib/go/thrift/header_transport_test.go +++ b/lib/go/thrift/header_transport_test.go @@ -23,7 +23,6 @@ import ( "context" "fmt" "io" - "io/ioutil" "strings" "testing" "testing/quick" @@ -87,7 +86,7 @@ func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocol if err := reader.ReadFrame(context.Background()); err != nil { t.Errorf("reader.ReadFrame returned error: %v", err) } - read, err := ioutil.ReadAll(reader) + read, err := io.ReadAll(reader) if err != nil { t.Errorf("Read returned error: %v", err) } @@ -305,3 +304,59 @@ func TestSetTHeaderTransportProtocolID(t *testing.T) { t.Errorf("Expected protocol id %v, got %v", expected, actual) } } + +func TestTHeaderTransportReuseTransport(t *testing.T) { + const ( + content = "Hello, world!" + n = 10 + ) + trans := NewTMemoryBuffer() + reader := NewTHeaderTransport(trans) + writer := NewTHeaderTransport(trans) + + t.Run("pair", func(t *testing.T) { + for i := 0; i < n; i++ { + // write + if _, err := io.Copy(writer, strings.NewReader(content)); err != nil { + t.Fatalf("Failed to write on #%d: %v", i, err) + } + if err := writer.Flush(context.Background()); err != nil { + t.Fatalf("Failed to flush on #%d: %v", i, err) + } + + // read + read, err := io.ReadAll(reader) + if err != nil { + t.Errorf("Failed to read on #%d: %v", i, err) + } + if string(read) != content { + t.Errorf("Read #%d: want %q, got %q", i, content, read) + } + } + }) + + t.Run("batched", func(t *testing.T) { + // write + for i := 0; i < n; i++ { + if _, err := io.Copy(writer, strings.NewReader(content)); err != nil { + t.Fatalf("Failed to write on #%d: %v", i, err) + } + if err := writer.Flush(context.Background()); err != nil { + t.Fatalf("Failed to flush on #%d: %v", i, err) + } + } + + // read + for i := 0; i < n; i++ { + buf := make([]byte, len(content)) + n, err := reader.Read(buf) + if err != nil { + 
t.Errorf("Failed to read on #%d: %v", i, err) + } + read := string(buf[:n]) + if string(read) != content { + t.Errorf("Read #%d: want %q, got %q", i, content, read) + } + } + }) +} |