From 57e24caa86afa8bacf444e66a9aef6203831416c Mon Sep 17 00:00:00 2001 From: Yuxuan 'fishy' Wang Date: Thu, 25 Mar 2021 17:00:31 -0700 Subject: THRIFT-5369: Use MaxMessageSize to check container sizes Client: go --- CHANGES.md | 4 +++ lib/go/thrift/binary_protocol.go | 19 ++++------- lib/go/thrift/compact_protocol.go | 12 +++---- lib/go/thrift/json_protocol.go | 21 +++++++++--- lib/go/thrift/simple_json_protocol.go | 63 ++++++++++++++++++++++++++++++----- 5 files changed, 87 insertions(+), 32 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index f1fadcd9d..d7ef279b7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,10 @@ - [THRIFT-5383](https://issues.apache.org/jira/browse/THRIFT-5383) - THRIFT-5383 TJSONProtocol Java readString throws on bounds check +### Go + +- [THRIFT-5369](https://issues.apache.org/jira/browse/THRIFT-5369) - No longer pre-allocating the whole container (map/set/list) in compiled go code to avoid huge allocations on malformed messages + ## 0.14.1 diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go index 45c880d32..3ed6608ee 100644 --- a/lib/go/thrift/binary_protocol.go +++ b/lib/go/thrift/binary_protocol.go @@ -23,7 +23,6 @@ import ( "bytes" "context" "encoding/binary" - "errors" "fmt" "io" "math" @@ -334,8 +333,6 @@ func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error { return nil } -var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length")) - func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) { k, e := p.ReadByte(ctx) if e != nil { @@ -354,8 +351,8 @@ func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, err = NewTProtocolException(e) return } - if size32 < 0 { - err = invalidDataLength + err = checkSizeForProtocol(size32, p.cfg) + if err != nil { return } size = int(size32) @@ -378,8 +375,8 @@ func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, si err = NewTProtocolException(e) return } - if size32 < 0 { - err = invalidDataLength + err = checkSizeForProtocol(size32, p.cfg) + if err != nil { return } size = int(size32) @@ -403,8 +400,8 @@ func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, siz err = NewTProtocolException(e) return } - if size32 < 0 { - err = invalidDataLength + err = checkSizeForProtocol(size32, p.cfg) + if err != nil { return } size = int(size32) @@ -466,10 +463,6 @@ func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err err if err != nil { return } - if size < 0 { - err = invalidDataLength - return - } if size == 0 { return "", nil } diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go index a49225dab..e0de07700 100644 --- a/lib/go/thrift/compact_protocol.go +++ b/lib/go/thrift/compact_protocol.go @@ -477,8 +477,8 @@ func (p *TCompactProtocol) ReadMapBegin(ctx context.Context) (keyType TType, val err = NewTProtocolException(e) return } - if size32 < 0 { - err = invalidDataLength + err = checkSizeForProtocol(size32, p.cfg) + if err != nil { return } size = int(size32) @@ -513,12 +513,12 @@ func (p *TCompactProtocol) ReadListBegin(ctx context.Context) (elemType TType, s err = NewTProtocolException(e) return } - if size2 < 0 { - err = invalidDataLength - return - } size = int(size2) } + err = checkSizeForProtocol(size32, p.cfg) + if err != nil { + return + } elemType, e := p.getTType(tCompactType(size_and_type)) if e != nil { err = NewTProtocolException(e) diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go index 8e59d16cf..98764fa88 100644 --- a/lib/go/thrift/json_protocol.go +++ b/lib/go/thrift/json_protocol.go @@ -311,9 +311,13 @@ func (p *TJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueT } // read size - iSize, e := p.ReadI64(ctx) - if e != nil { - return keyType, valueType, size, e + iSize, err := p.ReadI64(ctx) + if err != nil { + return keyType, valueType, size, err + } + err = checkSizeForProtocol(int32(iSize), p.cfg) + if err != nil { + return keyType, valueType, 0, err } size = int(iSize) @@ -485,9 +489,16 @@ func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) if err != nil { return elemType, size, err } - nSize, _, err2 := p.ParseI64() + nSize, _, err := p.ParseI64() + if err != nil { + return elemType, 0, err + } + err = checkSizeForProtocol(int32(nSize), p.cfg) + if err != nil { + return elemType, 0, err + } size = int(nSize) - return elemType, size, err2 + return elemType, size, nil } func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) { diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go index d1a815453..4967532c5 100644 --- a/lib/go/thrift/simple_json_protocol.go +++ b/lib/go/thrift/simple_json_protocol.go @@ -97,6 +97,8 @@ var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, error type TSimpleJSONProtocol struct { trans TTransport + cfg *TConfiguration + parseContextStack jsonContextStack dumpContext jsonContextStack @@ -104,9 +106,18 @@ type TSimpleJSONProtocol struct { reader *bufio.Reader } -// Constructor +// Deprecated: Use NewTSimpleJSONProtocolConf instead.: func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol { - v := &TSimpleJSONProtocol{trans: t, + return NewTSimpleJSONProtocolConf(t, &TConfiguration{ + noPropagation: true, + }) +} + +func NewTSimpleJSONProtocolConf(t TTransport, conf *TConfiguration) *TSimpleJSONProtocol { + PropagateTConfiguration(t, conf) + v := &TSimpleJSONProtocol{ + trans: t, + cfg: conf, writer: bufio.NewWriter(t), reader: bufio.NewReader(t), } @@ -116,14 +127,32 @@ func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol { } // Factory -type TSimpleJSONProtocolFactory struct{} +type TSimpleJSONProtocolFactory struct { + cfg *TConfiguration +} func (p *TSimpleJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { - return NewTSimpleJSONProtocol(trans) + return NewTSimpleJSONProtocolConf(trans, p.cfg) } +// SetTConfiguration implements TConfigurationSetter for propagation. +func (p *TSimpleJSONProtocolFactory) SetTConfiguration(conf *TConfiguration) { + p.cfg = conf +} + +// Deprecated: Use NewTSimpleJSONProtocolFactoryConf instead. func NewTSimpleJSONProtocolFactory() *TSimpleJSONProtocolFactory { - return &TSimpleJSONProtocolFactory{} + return &TSimpleJSONProtocolFactory{ + cfg: &TConfiguration{ + noPropagation: true, + }, + } +} + +func NewTSimpleJSONProtocolFactoryConf(conf *TConfiguration) *TSimpleJSONProtocolFactory { + return &TSimpleJSONProtocolFactory{ + cfg: conf, + } } var ( @@ -399,6 +428,13 @@ func (p *TSimpleJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, // read size iSize, err := p.ReadI64(ctx) + if err != nil { + return keyType, valueType, 0, err + } + err = checkSizeForProtocol(int32(size), p.cfg) + if err != nil { + return keyType, valueType, 0, err + } size = int(iSize) return keyType, valueType, size, err } @@ -1070,9 +1106,16 @@ func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e if err != nil { return elemType, size, err } - nSize, _, err2 := p.ParseI64() + nSize, _, err := p.ParseI64() + if err != nil { + return elemType, 0, err + } + err = checkSizeForProtocol(int32(nSize), p.cfg) + if err != nil { + return elemType, 0, err + } size = int(nSize) - return elemType, size, err2 + return elemType, size, nil } func (p *TSimpleJSONProtocol) ParseListEnd() error { @@ -1368,6 +1411,10 @@ func (p *TSimpleJSONProtocol) write(b []byte) (int, error) { // SetTConfiguration implements TConfigurationSetter for propagation. func (p *TSimpleJSONProtocol) SetTConfiguration(conf *TConfiguration) { PropagateTConfiguration(p.trans, conf) + p.cfg = conf } -var _ TConfigurationSetter = (*TSimpleJSONProtocol)(nil) +var ( + _ TConfigurationSetter = (*TSimpleJSONProtocol)(nil) + _ TConfigurationSetter = (*TSimpleJSONProtocolFactory)(nil) +) -- cgit v1.2.1