summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2020-10-10 18:39:32 -0700
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2020-10-14 10:14:03 -0700
commit64c2a4b87ab356e05033045492e51f1ad73a795b (patch)
tree926e9f08771156659c8cf3aa2c086d7592241300
parentdaf620915714b76fce517b376b963440d1f34089 (diff)
downloadthrift-64c2a4b87ab356e05033045492e51f1ad73a795b.tar.gz
THRIFT-5294: Fix panic in go TSimpleJSONProtocol
Client: go In go library's TSimpleJSONProtocol and TJSONProtocol implementations, we use slices as stacks for context info, but didn't do proper boundary check when peeking/popping, result in it might panic with using -1 as slice index in certain cases of calling Write*End without matching Write*Begin before. Refactor the code to properly implement the stack, and return a TProtocolException instead on those cases. Also add unit tests for all protocols. The unit tests shown that TCompactProtocol.[Read|Write]StructEnd would also panic with unmatched Begin calls, so fix them as well.
-rw-r--r--lib/go/thrift/compact_protocol.go7
-rw-r--r--lib/go/thrift/json_protocol.go4
-rw-r--r--lib/go/thrift/json_protocol_test.go4
-rw-r--r--lib/go/thrift/protocol_test.go89
-rw-r--r--lib/go/thrift/simple_json_protocol.go163
-rw-r--r--lib/go/thrift/simple_json_protocol_test.go55
6 files changed, 261 insertions, 61 deletions
diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index 8510f1f79..a0161959c 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -22,6 +22,7 @@ package thrift
import (
"context"
"encoding/binary"
+ "errors"
"fmt"
"io"
"math"
@@ -158,6 +159,9 @@ func (p *TCompactProtocol) WriteStructBegin(ctx context.Context, name string) er
// this as an opportunity to pop the last field from the current struct off
// of the field stack.
func (p *TCompactProtocol) WriteStructEnd(ctx context.Context) error {
+ if len(p.lastField) <= 0 {
+ return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("WriteStructEnd called without matching WriteStructBegin call before"))
+ }
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
@@ -386,6 +390,9 @@ func (p *TCompactProtocol) ReadStructBegin(ctx context.Context) (name string, er
// this struct from the field stack.
func (p *TCompactProtocol) ReadStructEnd(ctx context.Context) error {
// consume the last field we read off the wire.
+ if len(p.lastField) <= 0 {
+ return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("ReadStructEnd called without matching ReadStructBegin call before"))
+ }
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index 9a9328dc7..edc49cc10 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -41,8 +41,8 @@ type TJSONProtocol struct {
// Constructor
func NewTJSONProtocol(t TTransport) *TJSONProtocol {
v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)}
- v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
- v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
+ v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
+ v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
return v
}
diff --git a/lib/go/thrift/json_protocol_test.go b/lib/go/thrift/json_protocol_test.go
index 333d38321..39e52d150 100644
--- a/lib/go/thrift/json_protocol_test.go
+++ b/lib/go/thrift/json_protocol_test.go
@@ -648,3 +648,7 @@ func TestWriteJSONProtocolMap(t *testing.T) {
}
trans.Close()
}
+
+func TestTJSONProtocolUnmatchedBeginEnd(t *testing.T) {
+ UnmatchedBeginEndProtocolTest(t, NewTJSONProtocolFactory())
+}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index c1c67e8ca..caac78e99 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -217,6 +217,10 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
ReadWriteByte(t, p, trans)
trans.Close()
}
+
+ t.Run("UnmatchedBeginEnd", func(t *testing.T) {
+ UnmatchedBeginEndProtocolTest(t, protocolFactory)
+ })
}
func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
@@ -515,3 +519,88 @@ func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) {
}
}
}
+
+func UnmatchedBeginEndProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
+ // NOTE: not all protocol implementations do strict state check to
+ // return an error on unmatched Begin/End calls.
+ // This test is only meant to make sure that those unmatched Begin/End
+ // calls won't cause panic. There's no real "test" here.
+ trans := NewTMemoryBuffer()
+ t.Run("Read", func(t *testing.T) {
+ t.Run("Message", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadMessageEnd(context.Background())
+ p.ReadMessageEnd(context.Background())
+ })
+ t.Run("Struct", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadStructEnd(context.Background())
+ p.ReadStructEnd(context.Background())
+ })
+ t.Run("Field", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadFieldEnd(context.Background())
+ p.ReadFieldEnd(context.Background())
+ })
+ t.Run("Map", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadMapEnd(context.Background())
+ p.ReadMapEnd(context.Background())
+ })
+ t.Run("List", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadListEnd(context.Background())
+ p.ReadListEnd(context.Background())
+ })
+ t.Run("Set", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadSetEnd(context.Background())
+ p.ReadSetEnd(context.Background())
+ })
+ })
+ t.Run("Write", func(t *testing.T) {
+ t.Run("Message", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteMessageEnd(context.Background())
+ p.WriteMessageEnd(context.Background())
+ })
+ t.Run("Struct", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteStructEnd(context.Background())
+ p.WriteStructEnd(context.Background())
+ })
+ t.Run("Field", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteFieldEnd(context.Background())
+ p.WriteFieldEnd(context.Background())
+ })
+ t.Run("Map", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteMapEnd(context.Background())
+ p.WriteMapEnd(context.Background())
+ })
+ t.Run("List", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteListEnd(context.Background())
+ p.WriteListEnd(context.Background())
+ })
+ t.Run("Set", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteSetEnd(context.Background())
+ p.WriteSetEnd(context.Background())
+ })
+ })
+ trans.Close()
+}
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index d101b993c..e94b44bbb 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -25,6 +25,7 @@ import (
"context"
"encoding/base64"
"encoding/json"
+ "errors"
"fmt"
"io"
"math"
@@ -34,12 +35,13 @@ import (
type _ParseContext int
const (
- _CONTEXT_IN_TOPLEVEL _ParseContext = 1
- _CONTEXT_IN_LIST_FIRST _ParseContext = 2
- _CONTEXT_IN_LIST _ParseContext = 3
- _CONTEXT_IN_OBJECT_FIRST _ParseContext = 4
- _CONTEXT_IN_OBJECT_NEXT_KEY _ParseContext = 5
- _CONTEXT_IN_OBJECT_NEXT_VALUE _ParseContext = 6
+ _CONTEXT_INVALID _ParseContext = iota
+ _CONTEXT_IN_TOPLEVEL // 1
+ _CONTEXT_IN_LIST_FIRST // 2
+ _CONTEXT_IN_LIST // 3
+ _CONTEXT_IN_OBJECT_FIRST // 4
+ _CONTEXT_IN_OBJECT_NEXT_KEY // 5
+ _CONTEXT_IN_OBJECT_NEXT_VALUE // 6
)
func (p _ParseContext) String() string {
@@ -60,6 +62,32 @@ func (p _ParseContext) String() string {
return "UNKNOWN-PARSE-CONTEXT"
}
+type jsonContextStack []_ParseContext
+
+func (s *jsonContextStack) push(v _ParseContext) {
+ *s = append(*s, v)
+}
+
+func (s jsonContextStack) peek() (v _ParseContext, ok bool) {
+ l := len(s)
+ if l <= 0 {
+ return
+ }
+ return s[l-1], true
+}
+
+func (s *jsonContextStack) pop() (v _ParseContext, ok bool) {
+ l := len(*s)
+ if l <= 0 {
+ return
+ }
+ v = (*s)[l-1]
+ *s = (*s)[0 : l-1]
+ return v, true
+}
+
+var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Unexpected empty json protocol context stack"))
+
// Simple JSON protocol implementation for thrift.
//
// This protocol produces/consumes a simple output format
@@ -69,8 +97,8 @@ func (p _ParseContext) String() string {
type TSimpleJSONProtocol struct {
trans TTransport
- parseContextStack []int
- dumpContext []int
+ parseContextStack jsonContextStack
+ dumpContext jsonContextStack
writer *bufio.Writer
reader *bufio.Reader
@@ -82,8 +110,8 @@ func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol {
writer: bufio.NewWriter(t),
reader: bufio.NewReader(t),
}
- v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
- v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
+ v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
+ v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
return v
}
@@ -549,41 +577,41 @@ func (p *TSimpleJSONProtocol) Transport() TTransport {
}
func (p *TSimpleJSONProtocol) OutputPreValue() error {
- cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
switch cxt {
case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY:
if _, e := p.write(JSON_COMMA); e != nil {
return NewTProtocolException(e)
}
- break
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
if _, e := p.write(JSON_COLON); e != nil {
return NewTProtocolException(e)
}
- break
}
return nil
}
func (p *TSimpleJSONProtocol) OutputPostValue() error {
- cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
switch cxt {
case _CONTEXT_IN_LIST_FIRST:
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST))
- break
+ p.dumpContext.pop()
+ p.dumpContext.push(_CONTEXT_IN_LIST)
case _CONTEXT_IN_OBJECT_FIRST:
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
- break
+ p.dumpContext.pop()
+ p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_KEY:
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
- break
+ p.dumpContext.pop()
+ p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_KEY))
- break
+ p.dumpContext.pop()
+ p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
}
return nil
}
@@ -598,10 +626,13 @@ func (p *TSimpleJSONProtocol) OutputBool(value bool) error {
} else {
v = string(JSON_FALSE)
}
- switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
+ switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = jsonQuote(v)
- default:
}
if e := p.OutputStringData(v); e != nil {
return e
@@ -631,11 +662,14 @@ func (p *TSimpleJSONProtocol) OutputF64(value float64) error {
} else if math.IsInf(value, -1) {
v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE)
} else {
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
v = strconv.FormatFloat(value, 'g', -1, 64)
- switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+ switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = string(JSON_QUOTE) + v + string(JSON_QUOTE)
- default:
}
}
if e := p.OutputStringData(v); e != nil {
@@ -648,11 +682,14 @@ func (p *TSimpleJSONProtocol) OutputI64(value int64) error {
if e := p.OutputPreValue(); e != nil {
return e
}
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
v := strconv.FormatInt(value, 10)
- switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+ switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = jsonQuote(v)
- default:
}
if e := p.OutputStringData(v); e != nil {
return e
@@ -682,7 +719,7 @@ func (p *TSimpleJSONProtocol) OutputObjectBegin() error {
if _, e := p.write(JSON_LBRACE); e != nil {
return NewTProtocolException(e)
}
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST))
+ p.dumpContext.push(_CONTEXT_IN_OBJECT_FIRST)
return nil
}
@@ -690,7 +727,10 @@ func (p *TSimpleJSONProtocol) OutputObjectEnd() error {
if _, e := p.write(JSON_RBRACE); e != nil {
return NewTProtocolException(e)
}
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
+ _, ok := p.dumpContext.pop()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
if e := p.OutputPostValue(); e != nil {
return e
}
@@ -704,7 +744,7 @@ func (p *TSimpleJSONProtocol) OutputListBegin() error {
if _, e := p.write(JSON_LBRACKET); e != nil {
return NewTProtocolException(e)
}
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST))
+ p.dumpContext.push(_CONTEXT_IN_LIST_FIRST)
return nil
}
@@ -712,7 +752,10 @@ func (p *TSimpleJSONProtocol) OutputListEnd() error {
if _, e := p.write(JSON_RBRACKET); e != nil {
return NewTProtocolException(e)
}
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
+ _, ok := p.dumpContext.pop()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
if e := p.OutputPostValue(); e != nil {
return e
}
@@ -736,7 +779,10 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
if e := p.readNonSignificantWhitespace(); e != nil {
return NewTProtocolException(e)
}
- cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+ cxt, ok := p.parseContextStack.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
b, _ := p.reader.Peek(1)
switch cxt {
case _CONTEXT_IN_LIST:
@@ -755,7 +801,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
- break
case _CONTEXT_IN_OBJECT_NEXT_KEY:
if len(b) > 0 {
switch b[0] {
@@ -772,7 +817,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
- break
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
if len(b) > 0 {
switch b[0] {
@@ -787,7 +831,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
- break
}
return nil
}
@@ -796,20 +839,20 @@ func (p *TSimpleJSONProtocol) ParsePostValue() error {
if e := p.readNonSignificantWhitespace(); e != nil {
return NewTProtocolException(e)
}
- cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+ cxt, ok := p.parseContextStack.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
switch cxt {
case _CONTEXT_IN_LIST_FIRST:
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST))
- break
+ p.parseContextStack.pop()
+ p.parseContextStack.push(_CONTEXT_IN_LIST)
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
- break
+ p.parseContextStack.pop()
+ p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_KEY))
- break
+ p.parseContextStack.pop()
+ p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
}
return nil
}
@@ -962,7 +1005,7 @@ func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) {
}
if len(b) > 0 && b[0] == JSON_LBRACE[0] {
p.reader.ReadByte()
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST))
+ p.parseContextStack.push(_CONTEXT_IN_OBJECT_FIRST)
return false, nil
} else if p.safePeekContains(JSON_NULL) {
return true, nil
@@ -975,7 +1018,7 @@ func (p *TSimpleJSONProtocol) ParseObjectEnd() error {
if isNull, err := p.readIfNull(); isNull || err != nil {
return err
}
- cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+ cxt, _ := p.parseContextStack.peek()
if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != _CONTEXT_IN_OBJECT_NEXT_KEY) {
e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context (%d)", cxt)
return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -993,7 +1036,7 @@ func (p *TSimpleJSONProtocol) ParseObjectEnd() error {
break
}
}
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
+ p.parseContextStack.pop()
return p.ParsePostValue()
}
@@ -1007,7 +1050,7 @@ func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) {
return false, err
}
if len(b) >= 1 && b[0] == JSON_LBRACKET[0] {
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST))
+ p.parseContextStack.push(_CONTEXT_IN_LIST_FIRST)
p.reader.ReadByte()
isNull = false
} else if p.safePeekContains(JSON_NULL) {
@@ -1036,7 +1079,7 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
if isNull, err := p.readIfNull(); isNull || err != nil {
return err
}
- cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+ cxt, _ := p.parseContextStack.peek()
if cxt != _CONTEXT_IN_LIST {
e := fmt.Errorf("Expected to be in the List Context, but not in List Context (%d)", cxt)
return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -1054,8 +1097,10 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
break
}
}
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
- if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == _CONTEXT_IN_TOPLEVEL {
+ p.parseContextStack.pop()
+ if cxt, ok := p.parseContextStack.peek(); !ok {
+ return errEmptyJSONContextStack
+ } else if cxt == _CONTEXT_IN_TOPLEVEL {
return nil
}
return p.ParsePostValue()
@@ -1308,8 +1353,8 @@ func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool {
// Reset the context stack to its initial state.
func (p *TSimpleJSONProtocol) resetContextStack() {
- p.parseContextStack = []int{int(_CONTEXT_IN_TOPLEVEL)}
- p.dumpContext = []int{int(_CONTEXT_IN_TOPLEVEL)}
+ p.parseContextStack = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
+ p.dumpContext = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
}
func (p *TSimpleJSONProtocol) write(b []byte) (int, error) {
diff --git a/lib/go/thrift/simple_json_protocol_test.go b/lib/go/thrift/simple_json_protocol_test.go
index 986fff27e..89753c614 100644
--- a/lib/go/thrift/simple_json_protocol_test.go
+++ b/lib/go/thrift/simple_json_protocol_test.go
@@ -736,3 +736,58 @@ func TestWriteSimpleJSONProtocolSafePeek(t *testing.T) {
t.Fatalf("Should not match at test 3")
}
}
+
+func TestJSONContextStack(t *testing.T) {
+ var stack jsonContextStack
+ t.Run("empty-peek", func(t *testing.T) {
+ v, ok := stack.peek()
+ if ok {
+ t.Error("peek() on empty should return ok: false")
+ }
+ expected := _CONTEXT_INVALID
+ if v != expected {
+ t.Errorf("Expected value from peek() to be %v(%d), got %v(%d)", expected, expected, v, v)
+ }
+ })
+ t.Run("empty-pop", func(t *testing.T) {
+ v, ok := stack.pop()
+ if ok {
+ t.Error("pop() on empty should return ok: false")
+ }
+ expected := _CONTEXT_INVALID
+ if v != expected {
+ t.Errorf("Expected value from pop() to be %v(%d), got %v(%d)", expected, expected, v, v)
+ }
+ })
+ t.Run("push-peek-pop", func(t *testing.T) {
+ expected := _CONTEXT_INVALID
+ stack.push(expected)
+ if len(stack) != 1 {
+ t.Errorf("Expected stack to be as size 1 after push, got %#v", stack)
+ }
+ v, ok := stack.peek()
+ if !ok {
+ t.Error("peek() on non-empty should return ok: true")
+ }
+ if v != expected {
+ t.Errorf("Expected value from peek() to be %v(%d), got %v(%d)", expected, expected, v, v)
+ }
+ if len(stack) != 1 {
+ t.Errorf("Expected peek() to be read-only, got %#v", stack)
+ }
+ v, ok = stack.pop()
+ if !ok {
+ t.Error("pop() on non-empty should return ok: true")
+ }
+ if v != expected {
+ t.Errorf("Expected value from pop() to be %v(%d), got %v(%d)", expected, expected, v, v)
+ }
+ if len(stack) != 0 {
+ t.Errorf("Expected pop() to empty the stack, got %#v", stack)
+ }
+ })
+}
+
+func TestTSimpleJSONProtocolUnmatchedBeginEnd(t *testing.T) {
+ UnmatchedBeginEndProtocolTest(t, NewTSimpleJSONProtocolFactory())
+}