summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2020-12-10 14:42:37 -0800
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2020-12-16 09:31:18 -0800
commit37c2ceb737cb40377346c63a05f407da1c119ba0 (patch)
tree69efd7065e81348fd22f8fbc07ec475846e8a111
parentdda80547b10d698784713eb62a04f6f42eae107b (diff)
downloadthrift-37c2ceb737cb40377346c63a05f407da1c119ba0.tar.gz
THRIFT-5322: Guard against large string/binary lengths in Go
Client: go In TBinaryProtocol.ReadString, TBinaryProtocol.ReadBinary, TCompactProtocol.ReadString, and TCompactProtocol.ReadBinary, use safeReadBytes to prevent from large allocation on malformed sizes. $ go test -bench=SafeReadBytes -benchmem BenchmarkSafeReadBytes/normal-12 625057 1789 ns/op 2176 B/op 5 allocs/op BenchmarkSafeReadBytes/max-askedSize-12 545271 2236 ns/op 14464 B/op 7 allocs/op PASS
-rw-r--r--lib/go/thrift/binary_protocol.go56
-rw-r--r--lib/go/thrift/binary_protocol_test.go153
-rw-r--r--lib/go/thrift/compact_protocol.go17
3 files changed, 184 insertions, 42 deletions
diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go
index c87d23a1b..58956f673 100644
--- a/lib/go/thrift/binary_protocol.go
+++ b/lib/go/thrift/binary_protocol.go
@@ -432,6 +432,15 @@ func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err err
err = invalidDataLength
return
}
+ if size == 0 {
+ return "", nil
+ }
+ if size < int32(len(p.buffer)) {
+ // Avoid allocation on small reads
+ buf := p.buffer[:size]
+ read, e := io.ReadFull(p.trans, buf)
+ return string(buf[:read]), NewTProtocolException(e)
+ }
return p.readStringBody(size)
}
@@ -445,9 +454,7 @@ func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
return nil, invalidDataLength
}
- isize := int(size)
- buf := make([]byte, isize)
- _, err := io.ReadFull(p.trans, buf)
+ buf, err := safeReadBytes(size, p.trans)
return buf, NewTProtocolException(err)
}
@@ -479,38 +486,21 @@ func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) {
return NewTProtocolException(err)
}
-const readLimit = 32768
-
func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
- if size < 0 {
- return "", nil
- }
-
- var (
- buf bytes.Buffer
- e error
- b []byte
- )
+ buf, err := safeReadBytes(size, p.trans)
+ return string(buf), NewTProtocolException(err)
+}
- switch {
- case int(size) <= len(p.buffer):
- b = p.buffer[:size] // avoids allocation for small reads
- case int(size) < readLimit:
- b = make([]byte, size)
- default:
- b = make([]byte, readLimit)
+// This function is shared between TBinaryProtocol and TCompactProtocol.
+//
+// It tries to read size bytes from trans, in a way that prevents large
+// allocations when size is insanely large (mostly caused by malformed message).
+func safeReadBytes(size int32, trans io.Reader) ([]byte, error) {
+ if size < 0 {
+ return nil, nil
}
- for size > 0 {
- _, e = io.ReadFull(p.trans, b)
- buf.Write(b)
- if e != nil {
- break
- }
- size -= readLimit
- if size < readLimit && size > 0 {
- b = b[:size]
- }
- }
- return buf.String(), NewTProtocolException(e)
+ buf := new(bytes.Buffer)
+ _, err := io.CopyN(buf, trans, int64(size))
+ return buf.Bytes(), err
}
diff --git a/lib/go/thrift/binary_protocol_test.go b/lib/go/thrift/binary_protocol_test.go
index 0462cc79d..88bfd26b7 100644
--- a/lib/go/thrift/binary_protocol_test.go
+++ b/lib/go/thrift/binary_protocol_test.go
@@ -20,9 +20,162 @@
package thrift
import (
+ "bytes"
+ "math"
+ "strings"
"testing"
)
func TestReadWriteBinaryProtocol(t *testing.T) {
ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault())
}
+
+const (
+ safeReadBytesSource = `
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer sit amet
+tincidunt nibh. Phasellus vel convallis libero, sit amet posuere quam. Nullam
+blandit velit at nibh fringilla, sed egestas erat dapibus. Sed hendrerit
+tincidunt accumsan. Curabitur consectetur bibendum dui nec hendrerit. Fusce quis
+turpis nec magna efficitur volutpat a ut nibh. Vestibulum odio risus, tristique
+a nisi et, congue mattis mi. Vivamus a nunc justo. Mauris molestie sagittis
+magna, hendrerit auctor lectus egestas non. Phasellus pretium, odio sit amet
+bibendum feugiat, velit nunc luctus erat, ac bibendum mi dui molestie nulla.
+Nullam fermentum magna eu elit vehicula tincidunt. Etiam ornare laoreet
+dignissim. Ut sed nunc ac neque vulputate fermentum. Morbi volutpat dapibus
+magna, at porttitor quam facilisis a. Donec eget fermentum risus. Aliquam erat
+volutpat.
+
+Phasellus molestie id ante vel iaculis. Fusce eget quam nec quam viverra laoreet
+vitae a dui. Mauris blandit blandit dui, iaculis interdum diam mollis at. Morbi
+vel sem et.
+`
+ safeReadBytesSourceLen = len(safeReadBytesSource)
+)
+
+func TestSafeReadBytes(t *testing.T) {
+ srcData := []byte(safeReadBytesSource)
+
+ for _, c := range []struct {
+ label string
+ askedSize int32
+ dataSize int
+ }{
+ {
+ label: "normal",
+ askedSize: 100,
+ dataSize: 100,
+ },
+ {
+ label: "max-askedSize",
+ askedSize: math.MaxInt32,
+ dataSize: safeReadBytesSourceLen,
+ },
+ } {
+ t.Run(c.label, func(t *testing.T) {
+ data := bytes.NewReader(srcData[:c.dataSize])
+ buf, err := safeReadBytes(c.askedSize, data)
+ if len(buf) != c.dataSize {
+ t.Errorf(
+ "Expected to read %d bytes, got %d",
+ c.dataSize,
+ len(buf),
+ )
+ }
+ if !strings.HasPrefix(safeReadBytesSource, string(buf)) {
+ t.Errorf("Unexpected read data: %q", buf)
+ }
+ if int32(c.dataSize) < c.askedSize {
+ // We expect error in this case
+ if err == nil {
+ t.Errorf(
+ "Expected error when dataSize %d < askedSize %d, got nil",
+ c.dataSize,
+ c.askedSize,
+ )
+ }
+ } else {
+ // We expect no error in this case
+ if err != nil {
+ t.Errorf(
+ "Expected no error when dataSize %d >= askedSize %d, got: %v",
+ c.dataSize,
+ c.askedSize,
+ err,
+ )
+ }
+ }
+ })
+ }
+}
+
+func generateSafeReadBytesBenchmark(askedSize int32, dataSize int) func(b *testing.B) {
+ return func(b *testing.B) {
+ data := make([]byte, dataSize)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ safeReadBytes(askedSize, bytes.NewReader(data))
+ }
+ }
+}
+
+func TestSafeReadBytesAlloc(t *testing.T) {
+ if testing.Short() {
+ // NOTE: Since this test runs a benchmark test, it takes at
+ // least 1 second.
+ //
+ // In general we try to avoid unit tests taking that long to run,
+ // but it's to verify a security issue so we made an exception
+ // here:
+ // https://issues.apache.org/jira/browse/THRIFT-5322
+ t.Skip("skipping test in short mode.")
+ }
+
+ const (
+ askedSize = int32(math.MaxInt32)
+ dataSize = 4096
+ )
+
+ // The purpose of this test is that in the case a string header says
+ // that it has a string askedSize bytes long, the implementation should
+ // not just allocate askedSize bytes upfront. So when there're actually
+ // not enough data to be read (dataSize), the actual allocated bytes
+ // should be somewhere between dataSize and askedSize.
+ //
+ // Different approachs could have different memory overheads, so this
+ // target is arbitrary in nature. But when dataSize is small enough
+ // compare to askedSize, half the askedSize is a good and safe target.
+ const target = int64(askedSize) / 2
+
+ bm := testing.Benchmark(generateSafeReadBytesBenchmark(askedSize, dataSize))
+ actual := bm.AllocedBytesPerOp()
+ if actual > target {
+ t.Errorf(
+ "Expected allocated bytes per op to be <= %d, got %d",
+ target,
+ actual,
+ )
+ } else {
+ t.Logf("Allocated bytes: %d B/op", actual)
+ }
+}
+
+func BenchmarkSafeReadBytes(b *testing.B) {
+ for _, c := range []struct {
+ label string
+ askedSize int32
+ dataSize int
+ }{
+ {
+ label: "normal",
+ askedSize: 100,
+ dataSize: 100,
+ },
+ {
+ label: "max-askedSize",
+ askedSize: math.MaxInt32,
+ dataSize: 4096,
+ },
+ } {
+ b.Run(c.label, generateSafeReadBytesBenchmark(c.askedSize, c.dataSize))
+ }
+}
diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index a0161959c..424906d61 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -579,17 +579,17 @@ func (p *TCompactProtocol) ReadString(ctx context.Context) (value string, err er
if length < 0 {
return "", invalidDataLength
}
-
if length == 0 {
return "", nil
}
- var buf []byte
- if length <= int32(len(p.buffer)) {
- buf = p.buffer[0:length]
- } else {
- buf = make([]byte, length)
+ if length < int32(len(p.buffer)) {
+ // Avoid allocation on small reads
+ buf := p.buffer[:length]
+ read, e := io.ReadFull(p.trans, buf)
+ return string(buf[:read]), NewTProtocolException(e)
}
- _, e = io.ReadFull(p.trans, buf)
+
+ buf, e := safeReadBytes(length, p.trans)
return string(buf), NewTProtocolException(e)
}
@@ -606,8 +606,7 @@ func (p *TCompactProtocol) ReadBinary(ctx context.Context) (value []byte, err er
return nil, invalidDataLength
}
- buf := make([]byte, length)
- _, e = io.ReadFull(p.trans, buf)
+ buf, e := safeReadBytes(length, p.trans)
return buf, NewTProtocolException(e)
}