diff options
author | Yuxuan 'fishy' Wang <yuxuan.wang@reddit.com> | 2021-01-22 15:41:41 -0800 |
---|---|---|
committer | Yuxuan 'fishy' Wang <fishywang@gmail.com> | 2021-01-22 20:49:57 -0800 |
commit | 8dd04f4adfaea08699b1745c79f122bf9cbd6f07 (patch) | |
tree | 8edc3a6eb3ff75b1eab7809429ab12aa39d59c2c | |
parent | d9fcdd3dbafbe1a8296018d0d6c55d972f607a42 (diff) | |
download | thrift-8dd04f4adfaea08699b1745c79f122bf9cbd6f07.tar.gz |
THRIFT-5322: THeaderTransport protocol id fix
Client: go
This fixes a bug introduced in
https://github.com/apache/thrift/pull/2296, that we mixed the preferred
proto id and the detected proto id, which was a bad idea.
This change separates them, so when we propagate TConfiguration, we only
change the preferred one, which will only be used for new connections,
and leave the detected one from existing connections untouched.
Also add a test for it.
-rw-r--r-- | lib/go/thrift/header_transport.go | 17 | ||||
-rw-r--r-- | lib/go/thrift/header_transport_test.go | 24 |
2 files changed, 29 insertions, 12 deletions
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go index f1dc99ce3..f5736df42 100644 --- a/lib/go/thrift/header_transport.go +++ b/lib/go/thrift/header_transport.go @@ -264,6 +264,7 @@ type THeaderTransport struct { writeTransforms []THeaderTransformID clientType clientType + protocolID THeaderProtocolID cfg *TConfiguration // buffer is used in the following scenarios to avoid repetitive @@ -303,6 +304,7 @@ func NewTHeaderTransportConf(trans TTransport, conf *TConfiguration) *THeaderTra transport: trans, reader: bufio.NewReader(trans), writeHeaders: make(THeaderMap), + protocolID: conf.GetTHeaderProtocolID(), cfg: conf, } } @@ -443,16 +445,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e if err != nil { return err } - idPtr, err := THeaderProtocolIDPtr(THeaderProtocolID(protoID)) - if err != nil { - return err - } - if t.cfg == nil { - t.cfg = &TConfiguration{ - noPropagation: true, - } - } - t.cfg.THeaderProtocolID = idPtr + t.protocolID = THeaderProtocolID(protoID) var transformCount int32 transformCount, err = hp.readVarint32() @@ -597,7 +590,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error { headers := NewTMemoryBuffer() hp := NewTCompactProtocol(headers) hp.SetTConfiguration(t.cfg) - if _, err := hp.writeVarint32(int32(t.cfg.GetTHeaderProtocolID())); err != nil { + if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil { return NewTTransportExceptionFromError(err) } if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil { @@ -742,7 +735,7 @@ func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error { func (t *THeaderTransport) Protocol() THeaderProtocolID { switch t.clientType { default: - return t.cfg.GetTHeaderProtocolID() + return t.protocolID case clientFramedBinary, clientUnframedBinary: return THeaderProtocolBinary case clientFramedCompact, clientUnframedCompact: diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go index 41efb1898..65e69ee5a 100644 --- a/lib/go/thrift/header_transport_test.go +++ b/lib/go/thrift/header_transport_test.go @@ -281,3 +281,27 @@ func BenchmarkTHeaderProtocolIDValidate(b *testing.B) { }) } } + +func TestSetTHeaderTransportProtocolID(t *testing.T) { + const expected = THeaderProtocolCompact + factory := NewTHeaderTransportFactoryConf(nil, &TConfiguration{ + THeaderProtocolID: THeaderProtocolIDPtrMust(expected), + }) + buf := NewTMemoryBuffer() + trans, err := factory.GetTransport(buf) + if err != nil { + t.Fatalf("Failed to get transport from factory: %v", err) + } + ht, ok := trans.(*THeaderTransport) + if !ok { + t.Fatalf("Transport is not *THeaderTransport: %#v", trans) + } + if actual := ht.Protocol(); actual != expected { + t.Errorf("Expected protocol id %v, got %v", expected, actual) + } + + ht.SetTConfiguration(&TConfiguration{}) + if actual := ht.Protocol(); actual != expected { + t.Errorf("Expected protocol id %v, got %v", expected, actual) + } +} |