summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2021-01-22 15:41:41 -0800
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2021-01-22 20:49:57 -0800
commit8dd04f4adfaea08699b1745c79f122bf9cbd6f07 (patch)
tree8edc3a6eb3ff75b1eab7809429ab12aa39d59c2c
parentd9fcdd3dbafbe1a8296018d0d6c55d972f607a42 (diff)
downloadthrift-8dd04f4adfaea08699b1745c79f122bf9cbd6f07.tar.gz
THRIFT-5322: THeaderTransport protocol id fix
Client: go This fixes a bug introduced in https://github.com/apache/thrift/pull/2296, that we mixed the preferred proto id and the detected proto id, which was a bad idea. This change separates them, so when we propagate TConfiguration, we only change the preferred one, which will only be used for new connections, and leave the detected one from existing connections untouched. Also add a test for it.
-rw-r--r--lib/go/thrift/header_transport.go17
-rw-r--r--lib/go/thrift/header_transport_test.go24
2 files changed, 29 insertions, 12 deletions
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index f1dc99ce3..f5736df42 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -264,6 +264,7 @@ type THeaderTransport struct {
writeTransforms []THeaderTransformID
clientType clientType
+ protocolID THeaderProtocolID
cfg *TConfiguration
// buffer is used in the following scenarios to avoid repetitive
@@ -303,6 +304,7 @@ func NewTHeaderTransportConf(trans TTransport, conf *TConfiguration) *THeaderTra
transport: trans,
reader: bufio.NewReader(trans),
writeHeaders: make(THeaderMap),
+ protocolID: conf.GetTHeaderProtocolID(),
cfg: conf,
}
}
@@ -443,16 +445,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
if err != nil {
return err
}
- idPtr, err := THeaderProtocolIDPtr(THeaderProtocolID(protoID))
- if err != nil {
- return err
- }
- if t.cfg == nil {
- t.cfg = &TConfiguration{
- noPropagation: true,
- }
- }
- t.cfg.THeaderProtocolID = idPtr
+ t.protocolID = THeaderProtocolID(protoID)
var transformCount int32
transformCount, err = hp.readVarint32()
@@ -597,7 +590,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
headers := NewTMemoryBuffer()
hp := NewTCompactProtocol(headers)
hp.SetTConfiguration(t.cfg)
- if _, err := hp.writeVarint32(int32(t.cfg.GetTHeaderProtocolID())); err != nil {
+ if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil {
return NewTTransportExceptionFromError(err)
}
if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil {
@@ -742,7 +735,7 @@ func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error {
func (t *THeaderTransport) Protocol() THeaderProtocolID {
switch t.clientType {
default:
- return t.cfg.GetTHeaderProtocolID()
+ return t.protocolID
case clientFramedBinary, clientUnframedBinary:
return THeaderProtocolBinary
case clientFramedCompact, clientUnframedCompact:
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 41efb1898..65e69ee5a 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -281,3 +281,27 @@ func BenchmarkTHeaderProtocolIDValidate(b *testing.B) {
})
}
}
+
+func TestSetTHeaderTransportProtocolID(t *testing.T) {
+ const expected = THeaderProtocolCompact
+ factory := NewTHeaderTransportFactoryConf(nil, &TConfiguration{
+ THeaderProtocolID: THeaderProtocolIDPtrMust(expected),
+ })
+ buf := NewTMemoryBuffer()
+ trans, err := factory.GetTransport(buf)
+ if err != nil {
+ t.Fatalf("Failed to get transport from factory: %v", err)
+ }
+ ht, ok := trans.(*THeaderTransport)
+ if !ok {
+ t.Fatalf("Transport is not *THeaderTransport: %#v", trans)
+ }
+ if actual := ht.Protocol(); actual != expected {
+ t.Errorf("Expected protocol id %v, got %v", expected, actual)
+ }
+
+ ht.SetTConfiguration(&TConfiguration{})
+ if actual := ht.Protocol(); actual != expected {
+ t.Errorf("Expected protocol id %v, got %v", expected, actual)
+ }
+}