diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2020-12-10 14:42:37 -0800 |
---|---|---|
committer | Yuxuan 'fishy' Wang <fishywang@gmail.com> | 2020-12-16 09:31:18 -0800 |
commit | 37c2ceb737cb40377346c63a05f407da1c119ba0 (patch) | |
tree | 69efd7065e81348fd22f8fbc07ec475846e8a111 | |
parent | dda80547b10d698784713eb62a04f6f42eae107b (diff) | |
download | thrift-37c2ceb737cb40377346c63a05f407da1c119ba0.tar.gz |
THRIFT-5322: Guard against large string/binary lengths in Go
Client: go
In TBinaryProtocol.ReadString, TBinaryProtocol.ReadBinary,
TCompactProtocol.ReadString, and TCompactProtocol.ReadBinary, use
safeReadBytes to prevent large allocations on malformed sizes.
$ go test -bench=SafeReadBytes -benchmem
BenchmarkSafeReadBytes/normal-12 625057 1789 ns/op 2176 B/op 5 allocs/op
BenchmarkSafeReadBytes/max-askedSize-12 545271 2236 ns/op 14464 B/op 7 allocs/op
PASS
-rw-r--r-- | lib/go/thrift/binary_protocol.go | 56 | ||||
-rw-r--r-- | lib/go/thrift/binary_protocol_test.go | 153 | ||||
-rw-r--r-- | lib/go/thrift/compact_protocol.go | 17 |
3 files changed, 184 insertions, 42 deletions
diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go index c87d23a1b..58956f673 100644 --- a/lib/go/thrift/binary_protocol.go +++ b/lib/go/thrift/binary_protocol.go @@ -432,6 +432,15 @@ func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err err err = invalidDataLength return } + if size == 0 { + return "", nil + } + if size < int32(len(p.buffer)) { + // Avoid allocation on small reads + buf := p.buffer[:size] + read, e := io.ReadFull(p.trans, buf) + return string(buf[:read]), NewTProtocolException(e) + } return p.readStringBody(size) } @@ -445,9 +454,7 @@ func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) { return nil, invalidDataLength } - isize := int(size) - buf := make([]byte, isize) - _, err := io.ReadFull(p.trans, buf) + buf, err := safeReadBytes(size, p.trans) return buf, NewTProtocolException(err) } @@ -479,38 +486,21 @@ func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) { return NewTProtocolException(err) } -const readLimit = 32768 - func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) { - if size < 0 { - return "", nil - } - - var ( - buf bytes.Buffer - e error - b []byte - ) + buf, err := safeReadBytes(size, p.trans) + return string(buf), NewTProtocolException(err) +} - switch { - case int(size) <= len(p.buffer): - b = p.buffer[:size] // avoids allocation for small reads - case int(size) < readLimit: - b = make([]byte, size) - default: - b = make([]byte, readLimit) +// This function is shared between TBinaryProtocol and TCompactProtocol. +// +// It tries to read size bytes from trans, in a way that prevents large +// allocations when size is insanely large (mostly caused by malformed message). 
+func safeReadBytes(size int32, trans io.Reader) ([]byte, error) { + if size < 0 { + return nil, nil } - for size > 0 { - _, e = io.ReadFull(p.trans, b) - buf.Write(b) - if e != nil { - break - } - size -= readLimit - if size < readLimit && size > 0 { - b = b[:size] - } - } - return buf.String(), NewTProtocolException(e) + buf := new(bytes.Buffer) + _, err := io.CopyN(buf, trans, int64(size)) + return buf.Bytes(), err } diff --git a/lib/go/thrift/binary_protocol_test.go b/lib/go/thrift/binary_protocol_test.go index 0462cc79d..88bfd26b7 100644 --- a/lib/go/thrift/binary_protocol_test.go +++ b/lib/go/thrift/binary_protocol_test.go @@ -20,9 +20,162 @@ package thrift import ( + "bytes" + "math" + "strings" "testing" ) func TestReadWriteBinaryProtocol(t *testing.T) { ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault()) } + +const ( + safeReadBytesSource = ` +Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer sit amet +tincidunt nibh. Phasellus vel convallis libero, sit amet posuere quam. Nullam +blandit velit at nibh fringilla, sed egestas erat dapibus. Sed hendrerit +tincidunt accumsan. Curabitur consectetur bibendum dui nec hendrerit. Fusce quis +turpis nec magna efficitur volutpat a ut nibh. Vestibulum odio risus, tristique +a nisi et, congue mattis mi. Vivamus a nunc justo. Mauris molestie sagittis +magna, hendrerit auctor lectus egestas non. Phasellus pretium, odio sit amet +bibendum feugiat, velit nunc luctus erat, ac bibendum mi dui molestie nulla. +Nullam fermentum magna eu elit vehicula tincidunt. Etiam ornare laoreet +dignissim. Ut sed nunc ac neque vulputate fermentum. Morbi volutpat dapibus +magna, at porttitor quam facilisis a. Donec eget fermentum risus. Aliquam erat +volutpat. + +Phasellus molestie id ante vel iaculis. Fusce eget quam nec quam viverra laoreet +vitae a dui. Mauris blandit blandit dui, iaculis interdum diam mollis at. Morbi +vel sem et. 
+` + safeReadBytesSourceLen = len(safeReadBytesSource) +) + +func TestSafeReadBytes(t *testing.T) { + srcData := []byte(safeReadBytesSource) + + for _, c := range []struct { + label string + askedSize int32 + dataSize int + }{ + { + label: "normal", + askedSize: 100, + dataSize: 100, + }, + { + label: "max-askedSize", + askedSize: math.MaxInt32, + dataSize: safeReadBytesSourceLen, + }, + } { + t.Run(c.label, func(t *testing.T) { + data := bytes.NewReader(srcData[:c.dataSize]) + buf, err := safeReadBytes(c.askedSize, data) + if len(buf) != c.dataSize { + t.Errorf( + "Expected to read %d bytes, got %d", + c.dataSize, + len(buf), + ) + } + if !strings.HasPrefix(safeReadBytesSource, string(buf)) { + t.Errorf("Unexpected read data: %q", buf) + } + if int32(c.dataSize) < c.askedSize { + // We expect error in this case + if err == nil { + t.Errorf( + "Expected error when dataSize %d < askedSize %d, got nil", + c.dataSize, + c.askedSize, + ) + } + } else { + // We expect no error in this case + if err != nil { + t.Errorf( + "Expected no error when dataSize %d >= askedSize %d, got: %v", + c.dataSize, + c.askedSize, + err, + ) + } + } + }) + } +} + +func generateSafeReadBytesBenchmark(askedSize int32, dataSize int) func(b *testing.B) { + return func(b *testing.B) { + data := make([]byte, dataSize) + b.ResetTimer() + for i := 0; i < b.N; i++ { + safeReadBytes(askedSize, bytes.NewReader(data)) + } + } +} + +func TestSafeReadBytesAlloc(t *testing.T) { + if testing.Short() { + // NOTE: Since this test runs a benchmark test, it takes at + // least 1 second. 
+ // + // In general we try to avoid unit tests taking that long to run, + // but it's to verify a security issue so we made an exception + // here: + // https://issues.apache.org/jira/browse/THRIFT-5322 + t.Skip("skipping test in short mode.") + } + + const ( + askedSize = int32(math.MaxInt32) + dataSize = 4096 + ) + + // The purpose of this test is that in the case a string header says + // that it has a string askedSize bytes long, the implementation should + // not just allocate askedSize bytes upfront. So when there're actually + // not enough data to be read (dataSize), the actual allocated bytes + // should be somewhere between dataSize and askedSize. + // + // Different approachs could have different memory overheads, so this + // target is arbitrary in nature. But when dataSize is small enough + // compare to askedSize, half the askedSize is a good and safe target. + const target = int64(askedSize) / 2 + + bm := testing.Benchmark(generateSafeReadBytesBenchmark(askedSize, dataSize)) + actual := bm.AllocedBytesPerOp() + if actual > target { + t.Errorf( + "Expected allocated bytes per op to be <= %d, got %d", + target, + actual, + ) + } else { + t.Logf("Allocated bytes: %d B/op", actual) + } +} + +func BenchmarkSafeReadBytes(b *testing.B) { + for _, c := range []struct { + label string + askedSize int32 + dataSize int + }{ + { + label: "normal", + askedSize: 100, + dataSize: 100, + }, + { + label: "max-askedSize", + askedSize: math.MaxInt32, + dataSize: 4096, + }, + } { + b.Run(c.label, generateSafeReadBytesBenchmark(c.askedSize, c.dataSize)) + } +} diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go index a0161959c..424906d61 100644 --- a/lib/go/thrift/compact_protocol.go +++ b/lib/go/thrift/compact_protocol.go @@ -579,17 +579,17 @@ func (p *TCompactProtocol) ReadString(ctx context.Context) (value string, err er if length < 0 { return "", invalidDataLength } - if length == 0 { return "", nil } - var buf []byte - if length <= 
int32(len(p.buffer)) { - buf = p.buffer[0:length] - } else { - buf = make([]byte, length) + if length < int32(len(p.buffer)) { + // Avoid allocation on small reads + buf := p.buffer[:length] + read, e := io.ReadFull(p.trans, buf) + return string(buf[:read]), NewTProtocolException(e) } - _, e = io.ReadFull(p.trans, buf) + + buf, e := safeReadBytes(length, p.trans) return string(buf), NewTProtocolException(e) } @@ -606,8 +606,7 @@ func (p *TCompactProtocol) ReadBinary(ctx context.Context) (value []byte, err er return nil, invalidDataLength } - buf := make([]byte, length) - _, e = io.ReadFull(p.trans, buf) + buf, e := safeReadBytes(length, p.trans) return buf, NewTProtocolException(e) } |