summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>2020-12-16 17:10:48 -0800
committerYuxuan 'fishy' Wang <fishywang@gmail.com>2021-01-17 12:24:13 -0800
commitc4d1c0d80067986dbee124887bcb402ee1c6538e (patch)
tree60d2edf616dfd2f1ba9906bd10c9321ff349b4ae
parentc4e899a6d64aa97430ec9f7608d38db2095f6159 (diff)
downloadthrift-c4d1c0d80067986dbee124887bcb402ee1c6538e.tar.gz
THRIFT-5322: Implement TConfiguration in Go library
Client: go Define TConfiguration following the spec, and also move the following configurations scattered around different TTransport/TProtocol into it: - connect and socket timeouts for TSocket and TSSLSocket - tls config for TSSLSocket - max frame size for TFramedTransport - strict read and strict write for TBinaryProtocol - proto id for THeaderTransport Also add TConfiguration support for the following and their factories: - THeaderTransport and THeaderProtocol - TBinaryProtocol - TCompactProtocol - TFramedTransport - TSocket - TSSLSocket Also define TConfigurationSetter interface for easier TConfiguration propagation between wrapped TTransports/TProtocols , and add implementations to the following for propagation (they don't use anything from TConfiguration themselves): - StreamTransport - TBufferedTransport - TDebugProtocol - TJSONProtocol - TSimpleJSONProtocol - TZlibTransport TConfigurationSetter are not implemented by the factories of the "propagation only" TTransports/TProtocols, if they have a factory. For those use cases, TTransportFactoryConf and TProtocolFactoryConf are provided to wrap a factory with the ability to propagate TConfiguration. Also add simple sanity check for TBinaryProtocol and TCompactProtocol's ReadString and ReadBinary functions. Currently it only report error if the header length is larger than MaxMessageSize configured in TConfiguration, for simplicity.
-rw-r--r--CHANGES.md1
-rw-r--r--lib/go/thrift/binary_protocol.go75
-rw-r--r--lib/go/thrift/buffered_transport.go7
-rw-r--r--lib/go/thrift/compact_protocol.go61
-rw-r--r--lib/go/thrift/configuration.go378
-rw-r--r--lib/go/thrift/configuration_test.go338
-rw-r--r--lib/go/thrift/debug_protocol.go8
-rw-r--r--lib/go/thrift/framed_transport.go74
-rw-r--r--lib/go/thrift/header_protocol.go93
-rw-r--r--lib/go/thrift/header_protocol_test.go8
-rw-r--r--lib/go/thrift/header_transport.go103
-rw-r--r--lib/go/thrift/header_transport_test.go24
-rw-r--r--lib/go/thrift/iostream_transport.go8
-rw-r--r--lib/go/thrift/json_protocol.go2
-rw-r--r--lib/go/thrift/simple_json_protocol.go7
-rw-r--r--lib/go/thrift/socket.go103
-rw-r--r--lib/go/thrift/ssl_socket.go119
-rw-r--r--lib/go/thrift/zlib_transport.go7
18 files changed, 1234 insertions, 182 deletions
diff --git a/CHANGES.md b/CHANGES.md
index 65ed07f50..663c4c18c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -29,6 +29,7 @@
- [THRIFT-5164](https://issues.apache.org/jira/browse/THRIFT-5164) - Add ProcessorMiddleware function type and WrapProcessor function to support wrapping a TProcessor with middleware functions.
- [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - Add context deadline check to ReadMessageBegin in TBinaryProtocol, TCompactProtocol, and THeaderProtocol.
- [THRIFT-5240](https://issues.apache.org/jira/browse/THRIFT-5240) - The context passed into server handler implementations will be canceled when we detected that the client closed the connection.
+- [THRIFT-5322](https://issues.apache.org/jira/browse/THRIFT-5322) - Add support to TConfiguration, and also fix a bug that could cause excessive memory usage when reading malformed messages from TCompactProtocol.
## 0.13.0
diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go
index 58956f673..45c880d32 100644
--- a/lib/go/thrift/binary_protocol.go
+++ b/lib/go/thrift/binary_protocol.go
@@ -32,22 +32,37 @@ import (
type TBinaryProtocol struct {
trans TRichTransport
origTransport TTransport
- strictRead bool
- strictWrite bool
+ cfg *TConfiguration
buffer [64]byte
}
type TBinaryProtocolFactory struct {
- strictRead bool
- strictWrite bool
+ cfg *TConfiguration
}
+// Deprecated: Use NewTBinaryProtocolConf instead.
func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
- return NewTBinaryProtocol(t, false, true)
+ return NewTBinaryProtocolConf(t, &TConfiguration{
+ noPropagation: true,
+ })
}
+// Deprecated: Use NewTBinaryProtocolConf instead.
func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
- p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite}
+ return NewTBinaryProtocolConf(t, &TConfiguration{
+ TBinaryStrictRead: &strictRead,
+ TBinaryStrictWrite: &strictWrite,
+
+ noPropagation: true,
+ })
+}
+
+func NewTBinaryProtocolConf(t TTransport, conf *TConfiguration) *TBinaryProtocol {
+ PropagateTConfiguration(t, conf)
+ p := &TBinaryProtocol{
+ origTransport: t,
+ cfg: conf,
+ }
if et, ok := t.(TRichTransport); ok {
p.trans = et
} else {
@@ -56,16 +71,35 @@ func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProt
return p
}
+// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
- return NewTBinaryProtocolFactory(false, true)
+ return NewTBinaryProtocolFactoryConf(&TConfiguration{
+ noPropagation: true,
+ })
}
+// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
- return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite}
+ return NewTBinaryProtocolFactoryConf(&TConfiguration{
+ TBinaryStrictRead: &strictRead,
+ TBinaryStrictWrite: &strictWrite,
+
+ noPropagation: true,
+ })
+}
+
+func NewTBinaryProtocolFactoryConf(conf *TConfiguration) *TBinaryProtocolFactory {
+ return &TBinaryProtocolFactory{
+ cfg: conf,
+ }
}
func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
- return NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
+ return NewTBinaryProtocolConf(t, p.cfg)
+}
+
+func (p *TBinaryProtocolFactory) SetTConfiguration(conf *TConfiguration) {
+ p.cfg = conf
}
/**
@@ -73,7 +107,7 @@ func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
*/
func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
- if p.strictWrite {
+ if p.cfg.GetTBinaryStrictWrite() {
version := uint32(VERSION_1) | uint32(typeId)
e := p.WriteI32(ctx, int32(version))
if e != nil {
@@ -253,7 +287,7 @@ func (p *TBinaryProtocol) ReadMessageBegin(ctx context.Context) (name string, ty
}
return name, typeId, seqId, nil
}
- if p.strictRead {
+ if p.cfg.GetTBinaryStrictRead() {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
}
name, e2 := p.readStringBody(size)
@@ -428,6 +462,10 @@ func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err err
if e != nil {
return "", e
}
+ err = checkSizeForProtocol(size, p.cfg)
+ if err != nil {
+ return
+ }
if size < 0 {
err = invalidDataLength
return
@@ -450,8 +488,8 @@ func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
if e != nil {
return nil, e
}
- if size < 0 {
- return nil, invalidDataLength
+ if err := checkSizeForProtocol(size, p.cfg); err != nil {
+ return nil, err
}
buf, err := safeReadBytes(size, p.trans)
@@ -491,6 +529,17 @@ func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
return string(buf), NewTProtocolException(err)
}
+func (p *TBinaryProtocol) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.trans, conf)
+ PropagateTConfiguration(p.origTransport, conf)
+ p.cfg = conf
+}
+
+var (
+ _ TConfigurationSetter = (*TBinaryProtocolFactory)(nil)
+ _ TConfigurationSetter = (*TBinaryProtocol)(nil)
+)
+
// This function is shared between TBinaryProtocol and TCompactProtocol.
//
// It tries to read size bytes from trans, in a way that prevents large
diff --git a/lib/go/thrift/buffered_transport.go b/lib/go/thrift/buffered_transport.go
index 96702061b..aa551b4ab 100644
--- a/lib/go/thrift/buffered_transport.go
+++ b/lib/go/thrift/buffered_transport.go
@@ -90,3 +90,10 @@ func (p *TBufferedTransport) Flush(ctx context.Context) error {
func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) {
return p.tp.RemainingBytes()
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (p *TBufferedTransport) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.tp, conf)
+}
+
+var _ TConfigurationSetter = (*TBufferedTransport)(nil)
diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index 424906d61..25e6d0ccd 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -75,20 +75,37 @@ func init() {
}
}
-type TCompactProtocolFactory struct{}
+type TCompactProtocolFactory struct {
+ cfg *TConfiguration
+}
+// Deprecated: Use NewTCompactProtocolFactoryConf instead.
func NewTCompactProtocolFactory() *TCompactProtocolFactory {
- return &TCompactProtocolFactory{}
+ return NewTCompactProtocolFactoryConf(&TConfiguration{
+ noPropagation: true,
+ })
+}
+
+func NewTCompactProtocolFactoryConf(conf *TConfiguration) *TCompactProtocolFactory {
+ return &TCompactProtocolFactory{
+ cfg: conf,
+ }
}
func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol {
- return NewTCompactProtocol(trans)
+ return NewTCompactProtocolConf(trans, p.cfg)
+}
+
+func (p *TCompactProtocolFactory) SetTConfiguration(conf *TConfiguration) {
+ p.cfg = conf
}
type TCompactProtocol struct {
trans TRichTransport
origTransport TTransport
+ cfg *TConfiguration
+
// Used to keep track of the last field for the current and previous structs,
// so we can do the delta stuff.
lastField []int
@@ -107,9 +124,19 @@ type TCompactProtocol struct {
buffer [64]byte
}
-// Create a TCompactProtocol given a TTransport
+// Deprecated: Use NewTCompactProtocolConf instead.
func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
- p := &TCompactProtocol{origTransport: trans, lastField: []int{}}
+ return NewTCompactProtocolConf(trans, &TConfiguration{
+ noPropagation: true,
+ })
+}
+
+func NewTCompactProtocolConf(trans TTransport, conf *TConfiguration) *TCompactProtocol {
+ PropagateTConfiguration(trans, conf)
+ p := &TCompactProtocol{
+ origTransport: trans,
+ cfg: conf,
+ }
if et, ok := trans.(TRichTransport); ok {
p.trans = et
} else {
@@ -117,7 +144,6 @@ func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
}
return p
-
}
//
@@ -576,8 +602,9 @@ func (p *TCompactProtocol) ReadString(ctx context.Context) (value string, err er
if e != nil {
return "", NewTProtocolException(e)
}
- if length < 0 {
- return "", invalidDataLength
+ err = checkSizeForProtocol(length, p.cfg)
+ if err != nil {
+ return
}
if length == 0 {
return "", nil
@@ -599,12 +626,13 @@ func (p *TCompactProtocol) ReadBinary(ctx context.Context) (value []byte, err er
if e != nil {
return nil, NewTProtocolException(e)
}
+ err = checkSizeForProtocol(length, p.cfg)
+ if err != nil {
+ return
+ }
if length == 0 {
return []byte{}, nil
}
- if length < 0 {
- return nil, invalidDataLength
- }
buf, e := safeReadBytes(length, p.trans)
return buf, NewTProtocolException(e)
@@ -824,3 +852,14 @@ func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) {
func (p *TCompactProtocol) getCompactType(t TType) tCompactType {
return ttypeToCompactType[t]
}
+
+func (p *TCompactProtocol) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.trans, conf)
+ PropagateTConfiguration(p.origTransport, conf)
+ p.cfg = conf
+}
+
+var (
+ _ TConfigurationSetter = (*TCompactProtocolFactory)(nil)
+ _ TConfigurationSetter = (*TCompactProtocol)(nil)
+)
diff --git a/lib/go/thrift/configuration.go b/lib/go/thrift/configuration.go
new file mode 100644
index 000000000..454d9f377
--- /dev/null
+++ b/lib/go/thrift/configuration.go
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "crypto/tls"
+ "fmt"
+ "time"
+)
+
+// Default TConfiguration values.
+const (
+ DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024
+ DEFAULT_MAX_FRAME_SIZE = 16384000
+
+ DEFAULT_TBINARY_STRICT_READ = false
+ DEFAULT_TBINARY_STRICT_WRITE = true
+
+ DEFAULT_CONNECT_TIMEOUT = 0
+ DEFAULT_SOCKET_TIMEOUT = 0
+)
+
+// TConfiguration defines some configurations shared between TTransport,
+// TProtocol, TTransportFactory, TProtocolFactory, and other implementations.
+//
+// When constructing TConfiguration, you only need to specify the non-default
+// fields. All zero values have sane default values.
+//
+// Not all configurations defined are applicable to all implementations.
+// Implementations are free to ignore the configurations not applicable to them.
+//
+// All functions attached to this type are nil-safe.
+//
+// See [1] for spec.
+//
+// NOTE: When using TConfiguration, fill in all the configurations you want to
+// set across the stack, not only the ones you want to set in the immediate
+// TTransport/TProtocol.
+//
+// For example, say you want to migrate this old code into using TConfiguration:
+//
+// sccket := thrift.NewTSocketTimeout("host:port", time.Second)
+// transFactory := thrift.NewTFramedTransportFactoryMaxLength(
+// thrift.NewTTransportFactory(),
+// 1024 * 1024 * 256,
+// )
+// protoFactory := thrift.NewTBinaryProtocolFactory(true, true)
+//
+// This is the wrong way to do it because in the end the TConfiguration used by
+// socket and transFactory will be overwritten by the one used by protoFactory
+// because of TConfiguration propagation:
+//
+// // bad example, DO NOT USE
+// sccket := thrift.NewTSocketConf("host:port", &thrift.TConfiguration{
+// ConnectTimeout: time.Second,
+// SocketTimeout: time.Second,
+// })
+// transFactory := thrift.NewTFramedTransportFactoryConf(
+// thrift.NewTTransportFactory(),
+// &thrift.TConfiguration{
+// MaxFrameSize: 1024 * 1024 * 256,
+// },
+// )
+// protoFactory := thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{
+// TBinaryStrictRead: thrift.BoolPtr(true),
+// TBinaryStrictWrite: thrift.BoolPtr(true),
+// })
+//
+// This is the correct way to do it:
+//
+// conf := &thrift.TConfiguration{
+// ConnectTimeout: time.Second,
+// SocketTimeout: time.Second,
+//
+// MaxFrameSize: 1024 * 1024 * 256,
+//
+// TBinaryStrictRead: thrift.BoolPtr(true),
+// TBinaryStrictWrite: thrift.BoolPtr(true),
+// }
+// sccket := thrift.NewTSocketConf("host:port", conf)
+// transFactory := thrift.NewTFramedTransportFactoryConf(thrift.NewTTransportFactory(), conf)
+// protoFactory := thrift.NewTBinaryProtocolFactoryConf(conf)
+//
+// [1]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-tconfiguration.md
+type TConfiguration struct {
+ // If <= 0, DEFAULT_MAX_MESSAGE_SIZE will be used instead.
+ MaxMessageSize int32
+
+ // If <= 0, DEFAULT_MAX_FRAME_SIZE will be used instead.
+ //
+ // Also if MaxMessageSize < MaxFrameSize,
+ // MaxMessageSize will be used instead.
+ MaxFrameSize int32
+
+ // Connect and socket timeouts to be used by TSocket and TSSLSocket.
+ //
+ // 0 means no timeout.
+ //
+ // If <0, DEFAULT_CONNECT_TIMEOUT and DEFAULT_SOCKET_TIMEOUT will be
+ // used.
+ ConnectTimeout time.Duration
+ SocketTimeout time.Duration
+
+ // TLS config to be used by TSSLSocket.
+ TLSConfig *tls.Config
+
+ // Strict read/write configurations for TBinaryProtocol.
+ //
+ // BoolPtr helper function is available to use literal values.
+ TBinaryStrictRead *bool
+ TBinaryStrictWrite *bool
+
+ // The wrapped protocol id to be used in THeader transport/protocol.
+ //
+ // THeaderProtocolIDPtr and THeaderProtocolIDPtrMust helper functions
+ // are provided to help filling this value.
+ THeaderProtocolID *THeaderProtocolID
+
+ // Used internally by deprecated constructors, to avoid overriding
+ // underlying TTransport/TProtocol's cfg by accidental propagations.
+ //
+ // For external users this is always false.
+ noPropagation bool
+}
+
+// GetMaxMessageSize returns the max message size an implementation should
+// follow.
+//
+// It's nil-safe. DEFAULT_MAX_MESSAGE_SIZE will be returned if tc is nil.
+func (tc *TConfiguration) GetMaxMessageSize() int32 {
+ if tc == nil || tc.MaxMessageSize <= 0 {
+ return DEFAULT_MAX_MESSAGE_SIZE
+ }
+ return tc.MaxMessageSize
+}
+
+// GetMaxFrameSize returns the max frame size an implementation should follow.
+//
+// It's nil-safe. DEFAULT_MAX_FRAME_SIZE will be returned if tc is nil.
+//
+// If the configured max message size is smaller than the configured max frame
+// size, the smaller one will be returned instead.
+func (tc *TConfiguration) GetMaxFrameSize() int32 {
+ if tc == nil {
+ return DEFAULT_MAX_FRAME_SIZE
+ }
+ maxFrameSize := tc.MaxFrameSize
+ if maxFrameSize <= 0 {
+ maxFrameSize = DEFAULT_MAX_FRAME_SIZE
+ }
+ if maxMessageSize := tc.GetMaxMessageSize(); maxMessageSize < maxFrameSize {
+ return maxMessageSize
+ }
+ return maxFrameSize
+}
+
+// GetConnectTimeout returns the connect timeout should be used by TSocket and
+// TSSLSocket.
+//
+// It's nil-safe. If tc is nil, DEFAULT_CONNECT_TIMEOUT will be returned instead.
+func (tc *TConfiguration) GetConnectTimeout() time.Duration {
+ if tc == nil || tc.ConnectTimeout < 0 {
+ return DEFAULT_CONNECT_TIMEOUT
+ }
+ return tc.ConnectTimeout
+}
+
+// GetSocketTimeout returns the socket timeout should be used by TSocket and
+// TSSLSocket.
+//
+// It's nil-safe. If tc is nil, DEFAULT_SOCKET_TIMEOUT will be returned instead.
+func (tc *TConfiguration) GetSocketTimeout() time.Duration {
+ if tc == nil || tc.SocketTimeout < 0 {
+ return DEFAULT_SOCKET_TIMEOUT
+ }
+ return tc.SocketTimeout
+}
+
+// GetTLSConfig returns the tls config should be used by TSSLSocket.
+//
+// It's nil-safe. If tc is nil, nil will be returned instead.
+func (tc *TConfiguration) GetTLSConfig() *tls.Config {
+ if tc == nil {
+ return nil
+ }
+ return tc.TLSConfig
+}
+
+// GetTBinaryStrictRead returns the strict read configuration TBinaryProtocol
+// should follow.
+//
+// It's nil-safe. DEFAULT_TBINARY_STRICT_READ will be returned if either tc or
+// tc.TBinaryStrictRead is nil.
+func (tc *TConfiguration) GetTBinaryStrictRead() bool {
+ if tc == nil || tc.TBinaryStrictRead == nil {
+ return DEFAULT_TBINARY_STRICT_READ
+ }
+ return *tc.TBinaryStrictRead
+}
+
+// GetTBinaryStrictWrite returns the strict read configuration TBinaryProtocol
+// should follow.
+//
+// It's nil-safe. DEFAULT_TBINARY_STRICT_WRITE will be returned if either tc or
+// tc.TBinaryStrictWrite is nil.
+func (tc *TConfiguration) GetTBinaryStrictWrite() bool {
+ if tc == nil || tc.TBinaryStrictWrite == nil {
+ return DEFAULT_TBINARY_STRICT_WRITE
+ }
+ return *tc.TBinaryStrictWrite
+}
+
+// GetTHeaderProtocolID returns the THeaderProtocolID should be used by
+// THeaderProtocol clients (for servers, they always use the same one as the
+// client instead).
+//
+// It's nil-safe. If either tc or tc.THeaderProtocolID is nil,
+// THeaderProtocolDefault will be returned instead.
+// THeaderProtocolDefault will also be returned if configured value is invalid.
+func (tc *TConfiguration) GetTHeaderProtocolID() THeaderProtocolID {
+ if tc == nil || tc.THeaderProtocolID == nil {
+ return THeaderProtocolDefault
+ }
+ protoID := *tc.THeaderProtocolID
+ if err := protoID.Validate(); err != nil {
+ return THeaderProtocolDefault
+ }
+ return protoID
+}
+
+// THeaderProtocolIDPtr validates and returns the pointer to id.
+//
+// If id is not a valid THeaderProtocolID, a pointer to THeaderProtocolDefault
+// and the validation error will be returned.
+func THeaderProtocolIDPtr(id THeaderProtocolID) (*THeaderProtocolID, error) {
+ err := id.Validate()
+ if err != nil {
+ id = THeaderProtocolDefault
+ }
+ return &id, err
+}
+
+// THeaderProtocolIDPtrMust validates and returns the pointer to id.
+//
+// It's similar to THeaderProtocolIDPtr, but it panics on validation errors
+// instead of returning them.
+func THeaderProtocolIDPtrMust(id THeaderProtocolID) *THeaderProtocolID {
+ ptr, err := THeaderProtocolIDPtr(id)
+ if err != nil {
+ panic(err)
+ }
+ return ptr
+}
+
+// TConfigurationSetter is an optional interface TProtocol, TTransport,
+// TProtocolFactory, TTransportFactory, and other implementations can implement.
+//
+// It's intended to be called during intializations.
+// The behavior of calling SetTConfiguration on a TTransport/TProtocol in the
+// middle of a message is undefined:
+// It may or may not change the behavior of the current processing message,
+// and it may even cause the current message to fail.
+//
+// Note for implementations: SetTConfiguration might be called multiple times
+// with the same value in quick successions due to the implementation of the
+// propagation. Implementations should make SetTConfiguration as simple as
+// possible (usually just overwrite the stored configuration and propagate it to
+// the wrapped TTransports/TProtocols).
+type TConfigurationSetter interface {
+ SetTConfiguration(*TConfiguration)
+}
+
+// PropagateTConfiguration propagates cfg to impl if impl implements
+// TConfigurationSetter and cfg is non-nil, otherwise it does nothing.
+//
+// NOTE: nil cfg is not propagated. If you want to propagate a TConfiguration
+// with everything being default value, use &TConfiguration{} explicitly instead.
+func PropagateTConfiguration(impl interface{}, cfg *TConfiguration) {
+ if cfg == nil || cfg.noPropagation {
+ return
+ }
+
+ if setter, ok := impl.(TConfigurationSetter); ok {
+ setter.SetTConfiguration(cfg)
+ }
+}
+
+func checkSizeForProtocol(size int32, cfg *TConfiguration) error {
+ if size < 0 {
+ return NewTProtocolExceptionWithType(
+ NEGATIVE_SIZE,
+ fmt.Errorf("negative size: %d", size),
+ )
+ }
+ if size > cfg.GetMaxMessageSize() {
+ return NewTProtocolExceptionWithType(
+ SIZE_LIMIT,
+ fmt.Errorf("size exceeded max allowed: %d", size),
+ )
+ }
+ return nil
+}
+
+type tTransportFactoryConf struct {
+ delegate TTransportFactory
+ cfg *TConfiguration
+}
+
+func (f *tTransportFactoryConf) GetTransport(orig TTransport) (TTransport, error) {
+ trans, err := f.delegate.GetTransport(orig)
+ if err == nil {
+ PropagateTConfiguration(orig, f.cfg)
+ PropagateTConfiguration(trans, f.cfg)
+ }
+ return trans, err
+}
+
+func (f *tTransportFactoryConf) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(f.delegate, f.cfg)
+ f.cfg = cfg
+}
+
+// TTransportFactoryConf wraps a TTransportFactory to propagate
+// TConfiguration on the factory's GetTransport calls.
+func TTransportFactoryConf(delegate TTransportFactory, conf *TConfiguration) TTransportFactory {
+ return &tTransportFactoryConf{
+ delegate: delegate,
+ cfg: conf,
+ }
+}
+
+type tProtocolFactoryConf struct {
+ delegate TProtocolFactory
+ cfg *TConfiguration
+}
+
+func (f *tProtocolFactoryConf) GetProtocol(trans TTransport) TProtocol {
+ proto := f.delegate.GetProtocol(trans)
+ PropagateTConfiguration(trans, f.cfg)
+ PropagateTConfiguration(proto, f.cfg)
+ return proto
+}
+
+func (f *tProtocolFactoryConf) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(f.delegate, f.cfg)
+ f.cfg = cfg
+}
+
+// TProtocolFactoryConf wraps a TProtocolFactory to propagate
+// TConfiguration on the factory's GetProtocol calls.
+func TProtocolFactoryConf(delegate TProtocolFactory, conf *TConfiguration) TProtocolFactory {
+ return &tProtocolFactoryConf{
+ delegate: delegate,
+ cfg: conf,
+ }
+}
+
+var (
+ _ TConfigurationSetter = (*tTransportFactoryConf)(nil)
+ _ TConfigurationSetter = (*tProtocolFactoryConf)(nil)
+)
diff --git a/lib/go/thrift/configuration_test.go b/lib/go/thrift/configuration_test.go
new file mode 100644
index 000000000..f74784231
--- /dev/null
+++ b/lib/go/thrift/configuration_test.go
@@ -0,0 +1,338 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "crypto/tls"
+ "testing"
+ "time"
+)
+
+func TestTConfiguration(t *testing.T) {
+ invalidProtoID := THeaderProtocolID(-1)
+ if invalidProtoID.Validate() == nil {
+ t.Fatalf("Expected %v to be an invalid THeaderProtocolID, it passes the validation", invalidProtoID)
+ }
+
+ tlsConfig := &tls.Config{
+ Time: time.Now,
+ }
+
+ for _, c := range []struct {
+ label string
+ cfg *TConfiguration
+ expectedMessageSize int32
+ expectedFrameSize int32
+ expectedConnectTimeout time.Duration
+ expectedSocketTimeout time.Duration
+ expectedTLSConfig *tls.Config
+ expectedBinaryRead bool
+ expectedBinaryWrite bool
+ expectedProtoID THeaderProtocolID
+ }{
+ {
+ label: "nil",
+ cfg: nil,
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "empty",
+ cfg: &TConfiguration{},
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "normal",
+ cfg: &TConfiguration{
+ MaxMessageSize: 1024,
+ MaxFrameSize: 1024,
+ ConnectTimeout: time.Millisecond,
+ SocketTimeout: time.Millisecond * 2,
+ TLSConfig: tlsConfig,
+ TBinaryStrictRead: BoolPtr(true),
+ TBinaryStrictWrite: BoolPtr(false),
+ THeaderProtocolID: THeaderProtocolIDPtrMust(THeaderProtocolCompact),
+ },
+ expectedMessageSize: 1024,
+ expectedFrameSize: 1024,
+ expectedConnectTimeout: time.Millisecond,
+ expectedSocketTimeout: time.Millisecond * 2,
+ expectedTLSConfig: tlsConfig,
+ expectedBinaryRead: true,
+ expectedBinaryWrite: false,
+ expectedProtoID: THeaderProtocolCompact,
+ },
+ {
+ label: "message<frame",
+ cfg: &TConfiguration{
+ MaxMessageSize: 1024,
+ MaxFrameSize: 4096,
+ },
+ expectedMessageSize: 1024,
+ expectedFrameSize: 1024,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "frame<message",
+ cfg: &TConfiguration{
+ MaxMessageSize: 4096,
+ MaxFrameSize: 1024,
+ },
+ expectedMessageSize: 4096,
+ expectedFrameSize: 1024,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "negative-message-size",
+ cfg: &TConfiguration{
+ MaxMessageSize: -1,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "negative-frame-size",
+ cfg: &TConfiguration{
+ MaxFrameSize: -1,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "negative-connect-timeout",
+ cfg: &TConfiguration{
+ ConnectTimeout: -1,
+ SocketTimeout: time.Millisecond,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: time.Millisecond,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "negative-socket-timeout",
+ cfg: &TConfiguration{
+ SocketTimeout: -1,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "invalid-proto-id",
+ cfg: &TConfiguration{
+ THeaderProtocolID: &invalidProtoID,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ } {
+ t.Run(c.label, func(t *testing.T) {
+ t.Run("GetMaxMessageSize", func(t *testing.T) {
+ actual := c.cfg.GetMaxMessageSize()
+ if actual != c.expectedMessageSize {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedMessageSize,
+ actual,
+ )
+ }
+ })
+ t.Run("GetMaxFrameSize", func(t *testing.T) {
+ actual := c.cfg.GetMaxFrameSize()
+ if actual != c.expectedFrameSize {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedFrameSize,
+ actual,
+ )
+ }
+ })
+ t.Run("GetConnectTimeout", func(t *testing.T) {
+ actual := c.cfg.GetConnectTimeout()
+ if actual != c.expectedConnectTimeout {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedConnectTimeout,
+ actual,
+ )
+ }
+ })
+ t.Run("GetSocketTimeout", func(t *testing.T) {
+ actual := c.cfg.GetSocketTimeout()
+ if actual != c.expectedSocketTimeout {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedSocketTimeout,
+ actual,
+ )
+ }
+ })
+ t.Run("GetTLSConfig", func(t *testing.T) {
+ actual := c.cfg.GetTLSConfig()
+ if actual != c.expectedTLSConfig {
+ t.Errorf(
+ "Expected %p(%#v), got %p(%#v)",
+ c.expectedTLSConfig,
+ c.expectedTLSConfig,
+ actual,
+ actual,
+ )
+ }
+ })
+ t.Run("GetTBinaryStrictRead", func(t *testing.T) {
+ actual := c.cfg.GetTBinaryStrictRead()
+ if actual != c.expectedBinaryRead {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedBinaryRead,
+ actual,
+ )
+ }
+ })
+ t.Run("GetTBinaryStrictWrite", func(t *testing.T) {
+ actual := c.cfg.GetTBinaryStrictWrite()
+ if actual != c.expectedBinaryWrite {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedBinaryWrite,
+ actual,
+ )
+ }
+ })
+ t.Run("GetTHeaderProtocolID", func(t *testing.T) {
+ actual := c.cfg.GetTHeaderProtocolID()
+ if actual != c.expectedProtoID {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedProtoID,
+ actual,
+ )
+ }
+ })
+ })
+ }
+}
+
+func TestTHeaderProtocolIDPtr(t *testing.T) {
+ var invalidProtoID = THeaderProtocolID(-1)
+ if invalidProtoID.Validate() == nil {
+ t.Fatalf("Expected %v to be an invalid THeaderProtocolID, it passes the validation", invalidProtoID)
+ }
+
+ ptr, err := THeaderProtocolIDPtr(invalidProtoID)
+ if err == nil {
+ t.Error("Expected error on invalid proto id, got nil")
+ }
+ if ptr == nil {
+ t.Fatal("Expected non-nil pointer on invalid proto id, got nil")
+ }
+ if *ptr != THeaderProtocolDefault {
+ t.Errorf("Expected pointer to %v, got %v", THeaderProtocolDefault, *ptr)
+ }
+}
+
+func TestTHeaderProtocolIDPtrMust(t *testing.T) {
+ const expected = THeaderProtocolCompact
+ ptr := THeaderProtocolIDPtrMust(expected)
+ if *ptr != expected {
+ t.Errorf("Expected pointer to %v, got %v", expected, *ptr)
+ }
+}
+
+func TestTHeaderProtocolIDPtrMustPanic(t *testing.T) {
+ var invalidProtoID = THeaderProtocolID(-1)
+ if invalidProtoID.Validate() == nil {
+ t.Fatalf("Expected %v to be an invalid THeaderProtocolID, it passes the validation", invalidProtoID)
+ }
+
+ defer func() {
+ if recovered := recover(); recovered == nil {
+ t.Error("Expected panic on invalid proto id, did not happen.")
+ }
+ }()
+
+ THeaderProtocolIDPtrMust(invalidProtoID)
+}
+
+func TestPropagateTConfiguration(t *testing.T) {
+ cfg := &TConfiguration{}
+ // Just make sure it won't cause panics on some nil
+ // TProtocol/TTransport/TProtocolFactory/TTransportFactory values.
+ PropagateTConfiguration(nil, cfg)
+ var proto TProtocol
+ PropagateTConfiguration(proto, cfg)
+ var protoFactory TProtocolFactory
+ PropagateTConfiguration(protoFactory, cfg)
+ var trans TTransport
+ PropagateTConfiguration(trans, cfg)
+ var transFactory TTransportFactory
+ PropagateTConfiguration(transFactory, cfg)
+}
diff --git a/lib/go/thrift/debug_protocol.go b/lib/go/thrift/debug_protocol.go
index 875844b00..fdf9bfec1 100644
--- a/lib/go/thrift/debug_protocol.go
+++ b/lib/go/thrift/debug_protocol.go
@@ -437,3 +437,11 @@ func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) {
func (tdp *TDebugProtocol) Transport() TTransport {
return tdp.Delegate.Transport()
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (tdp *TDebugProtocol) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(tdp.Delegate, conf)
+ PropagateTConfiguration(tdp.DuplicateTo, conf)
+}
+
+var _ TConfigurationSetter = (*TDebugProtocol)(nil)
diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go
index f1920751a..f683e7f54 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -28,11 +28,13 @@ import (
"io"
)
+// Deprecated: Use DEFAULT_MAX_FRAME_SIZE instead.
const DEFAULT_MAX_LENGTH = 16384000
type TFramedTransport struct {
transport TTransport
- maxLength uint32
+
+ cfg *TConfiguration
writeBuf bytes.Buffer
@@ -43,32 +45,75 @@ type TFramedTransport struct {
}
type tFramedTransportFactory struct {
- factory TTransportFactory
- maxLength uint32
+ factory TTransportFactory
+ cfg *TConfiguration
}
+// Deprecated: Use NewTFramedTransportFactoryConf instead.
func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory {
- return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH}
+ return NewTFramedTransportFactoryConf(factory, &TConfiguration{
+ MaxFrameSize: DEFAULT_MAX_LENGTH,
+
+ noPropagation: true,
+ })
}
+// Deprecated: Use NewTFramedTransportFactoryConf instead.
func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory {
- return &tFramedTransportFactory{factory: factory, maxLength: maxLength}
+ return NewTFramedTransportFactoryConf(factory, &TConfiguration{
+ MaxFrameSize: int32(maxLength),
+
+ noPropagation: true,
+ })
+}
+
+func NewTFramedTransportFactoryConf(factory TTransportFactory, conf *TConfiguration) TTransportFactory {
+ PropagateTConfiguration(factory, conf)
+ return &tFramedTransportFactory{
+ factory: factory,
+ cfg: conf,
+ }
}
func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) {
+ PropagateTConfiguration(base, p.cfg)
tt, err := p.factory.GetTransport(base)
if err != nil {
return nil, err
}
- return NewTFramedTransportMaxLength(tt, p.maxLength), nil
+ return NewTFramedTransportConf(tt, p.cfg), nil
+}
+
+func (p *tFramedTransportFactory) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(p.factory, cfg)
+ p.cfg = cfg
}
+// Deprecated: Use NewTFramedTransportConf instead.
func NewTFramedTransport(transport TTransport) *TFramedTransport {
- return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH}
+ return NewTFramedTransportConf(transport, &TConfiguration{
+ MaxFrameSize: DEFAULT_MAX_LENGTH,
+
+ noPropagation: true,
+ })
}
+// Deprecated: Use NewTFramedTransportConf instead.
func NewTFramedTransportMaxLength(transport TTransport, maxLength uint32) *TFramedTransport {
- return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength}
+ return NewTFramedTransportConf(transport, &TConfiguration{
+ MaxFrameSize: int32(maxLength),
+
+ noPropagation: true,
+ })
+}
+
+func NewTFramedTransportConf(transport TTransport, conf *TConfiguration) *TFramedTransport {
+ PropagateTConfiguration(transport, conf)
+ return &TFramedTransport{
+ transport: transport,
+ reader: bufio.NewReader(transport),
+ cfg: conf,
+ }
}
func (p *TFramedTransport) Open() error {
@@ -155,7 +200,7 @@ func (p *TFramedTransport) readFrame() error {
return err
}
size := binary.BigEndian.Uint32(buf)
- if size < 0 || size > p.maxLength {
+ if size < 0 || size > uint32(p.cfg.GetMaxFrameSize()) {
return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
}
_, err := io.CopyN(&p.readBuf, p.reader, int64(size))
@@ -165,3 +210,14 @@ func (p *TFramedTransport) readFrame() error {
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
return uint64(p.readBuf.Len())
}
+
+// SetTConfiguration implements TConfigurationSetter.
+func (p *TFramedTransport) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(p.transport, cfg)
+ p.cfg = cfg
+}
+
+var (
+ _ TConfigurationSetter = (*tFramedTransportFactory)(nil)
+ _ TConfigurationSetter = (*TFramedTransport)(nil)
+)
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index f86d558aa..5ad48e43b 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -34,76 +34,65 @@ type THeaderProtocol struct {
// Will be initialized on first read/write.
protocol TProtocol
+
+ cfg *TConfiguration
+}
+
+// Deprecated: Use NewTHeaderProtocolConf instead.
+func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
+ return newTHeaderProtocolConf(trans, &TConfiguration{
+ noPropagation: true,
+ })
}
-// NewTHeaderProtocol creates a new THeaderProtocol from the underlying
-// transport with default protocol ID.
+// NewTHeaderProtocolConf creates a new THeaderProtocol from the underlying
+// transport with given TConfiguration.
//
// The passed in transport will be wrapped with THeaderTransport.
//
// Note that THeaderTransport handles frame and zlib by itself,
// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
-func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
- p, err := newTHeaderProtocolWithProtocolID(trans, THeaderProtocolDefault)
- if err != nil {
- // Since we used THeaderProtocolDefault this should never happen,
- // but put a sanity check here just in case.
- panic(err)
- }
- return p
+func NewTHeaderProtocolConf(trans TTransport, conf *TConfiguration) *THeaderProtocol {
+ return newTHeaderProtocolConf(trans, conf)
}
-func newTHeaderProtocolWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderProtocol, error) {
- t, err := NewTHeaderTransportWithProtocolID(trans, protoID)
- if err != nil {
- return nil, err
- }
- p, err := t.protocolID.GetProtocol(t)
- if err != nil {
- return nil, err
- }
+func newTHeaderProtocolConf(trans TTransport, cfg *TConfiguration) *THeaderProtocol {
+ t := NewTHeaderTransportConf(trans, cfg)
+ p, _ := t.cfg.GetTHeaderProtocolID().GetProtocol(t)
+ PropagateTConfiguration(p, cfg)
return &THeaderProtocol{
transport: t,
protocol: p,
- }, nil
+ cfg: cfg,
+ }
}
type tHeaderProtocolFactory struct {
- protoID THeaderProtocolID
+ cfg *TConfiguration
}
func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
- p, err := newTHeaderProtocolWithProtocolID(trans, f.protoID)
- if err != nil {
- // Currently there's no way for external users to construct a
- // valid factory with invalid protoID, so this should never
- // happen. But put a sanity check here just in case in the
- // future a bug made that possible.
- panic(err)
- }
- return p
+ return newTHeaderProtocolConf(trans, f.cfg)
}
-// NewTHeaderProtocolFactory creates a factory for THeader with default protocol
-// ID.
-//
-// It's a wrapper for NewTHeaderProtocol
+func (f *tHeaderProtocolFactory) SetTConfiguration(cfg *TConfiguration) {
+ f.cfg = cfg
+}
+
+// Deprecated: Use NewTHeaderProtocolFactoryConf instead.
func NewTHeaderProtocolFactory() TProtocolFactory {
- return tHeaderProtocolFactory{
- protoID: THeaderProtocolDefault,
- }
+ return NewTHeaderProtocolFactoryConf(&TConfiguration{
+ noPropagation: true,
+ })
}
-// NewTHeaderProtocolFactoryWithProtocolID creates a factory for THeader with
-// given protocol ID.
-func NewTHeaderProtocolFactoryWithProtocolID(protoID THeaderProtocolID) (TProtocolFactory, error) {
- if err := protoID.Validate(); err != nil {
- return nil, err
- }
+// NewTHeaderProtocolFactoryConf creates a factory for THeader with given
+// TConfiguration.
+func NewTHeaderProtocolFactoryConf(conf *TConfiguration) TProtocolFactory {
return tHeaderProtocolFactory{
- protoID: protoID,
- }, nil
+ cfg: conf,
+ }
}
// Transport returns the underlying transport.
@@ -142,6 +131,7 @@ func (p *THeaderProtocol) WriteMessageBegin(ctx context.Context, name string, ty
if err != nil {
return err
}
+ PropagateTConfiguration(newProto, p.cfg)
p.protocol = newProto
p.transport.SequenceID = seqID
return p.protocol.WriteMessageBegin(ctx, name, typeID, seqID)
@@ -261,6 +251,7 @@ func (p *THeaderProtocol) ReadMessageBegin(ctx context.Context) (name string, ty
}
return
}
+ PropagateTConfiguration(newProto, p.cfg)
p.protocol = newProto
return p.protocol.ReadMessageBegin(ctx)
@@ -346,6 +337,13 @@ func (p *THeaderProtocol) Skip(ctx context.Context, fieldType TType) error {
return p.protocol.Skip(ctx, fieldType)
}
+// SetTConfiguration implements TConfigurationSetter.
+func (p *THeaderProtocol) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(p.transport, cfg)
+ PropagateTConfiguration(p.protocol, cfg)
+ p.cfg = cfg
+}
+
// GetResponseHeadersFromClient is a helper function to get the read THeaderMap
// from the last response received from the given client.
//
@@ -359,3 +357,8 @@ func GetResponseHeadersFromClient(c TClient) THeaderMap {
}
return nil
}
+
+var (
+ _ TConfigurationSetter = (*tHeaderProtocolFactory)(nil)
+ _ TConfigurationSetter = (*THeaderProtocol)(nil)
+)
diff --git a/lib/go/thrift/header_protocol_test.go b/lib/go/thrift/header_protocol_test.go
index f66ea6463..48a69bf23 100644
--- a/lib/go/thrift/header_protocol_test.go
+++ b/lib/go/thrift/header_protocol_test.go
@@ -34,11 +34,9 @@ func TestReadWriteHeaderProtocol(t *testing.T) {
t.Run(
"compact",
func(t *testing.T) {
- f, err := NewTHeaderProtocolFactoryWithProtocolID(THeaderProtocolCompact)
- if err != nil {
- t.Fatal(err)
- }
- ReadWriteProtocolTest(t, f)
+ ReadWriteProtocolTest(t, NewTHeaderProtocolFactoryConf(&TConfiguration{
+ THeaderProtocolID: THeaderProtocolIDPtrMust(THeaderProtocolCompact),
+ }))
},
)
}
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index 562d02fa4..1e8e30244 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -264,7 +264,7 @@ type THeaderTransport struct {
writeTransforms []THeaderTransformID
clientType clientType
- protocolID THeaderProtocolID
+ cfg *TConfiguration
// buffer is used in the following scenarios to avoid repetitive
// allocations, while 4 is big enough for all those scenarios:
@@ -276,51 +276,35 @@ type THeaderTransport struct {
var _ TTransport = (*THeaderTransport)(nil)
-// NewTHeaderTransport creates THeaderTransport from the underlying transport.
-//
-// Please note that THeaderTransport handles framing and zlib by itself,
-// so the underlying transport should be the raw socket transports (TSocket or TSSLSocket),
-// instead of rich transports like TZlibTransport or TFramedTransport.
-//
-// If trans is already a *THeaderTransport, it will be returned as is.
+// Deprecated: Use NewTHeaderTransportConf instead.
func NewTHeaderTransport(trans TTransport) *THeaderTransport {
- if ht, ok := trans.(*THeaderTransport); ok {
- return ht
- }
- return &THeaderTransport{
- transport: trans,
- reader: bufio.NewReader(trans),
- writeHeaders: make(THeaderMap),
- protocolID: THeaderProtocolDefault,
- }
+ return NewTHeaderTransportConf(trans, &TConfiguration{
+ noPropagation: true,
+ })
}
-// NewTHeaderTransportWithProtocolID creates THeaderTransport from the
-// underlying transport, with given protocol ID set.
+// NewTHeaderTransportConf creates THeaderTransport from the
+// underlying transport, with given TConfiguration attached.
//
// If trans is already a *THeaderTransport, it will be returned as is,
-// but with protocol ID overridden by the value passed in.
-//
-// If the passed in protocol ID is an invalid/unsupported one,
-// this function returns error.
+// but with TConfiguration overridden by the value passed in.
//
-// The protocol ID overridden is only useful for client transports.
+// The protocol ID in TConfiguration is only useful for client transports.
// For servers,
// the protocol ID will be overridden again to the one set by the client,
// to ensure that servers always speak the same dialect as the client.
-func NewTHeaderTransportWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderTransport, error) {
- if err := protoID.Validate(); err != nil {
- return nil, err
- }
+func NewTHeaderTransportConf(trans TTransport, conf *TConfiguration) *THeaderTransport {
if ht, ok := trans.(*THeaderTransport); ok {
- return ht, nil
+ ht.SetTConfiguration(conf)
+ return ht
}
+ PropagateTConfiguration(trans, conf)
return &THeaderTransport{
transport: trans,
reader: bufio.NewReader(trans),
writeHeaders: make(THeaderMap),
- protocolID: protoID,
- }, nil
+ cfg: conf,
+ }
}
// Open calls the underlying transport's Open function.
@@ -375,7 +359,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
// At this point it should be a framed message,
// sanity check on frameSize then discard the peeked part.
- if frameSize > THeaderMaxFrameSize {
+ if frameSize > THeaderMaxFrameSize || frameSize > uint32(t.cfg.GetMaxFrameSize()) {
return NewTProtocolExceptionWithType(
SIZE_LIMIT,
errors.New("frame too large"),
@@ -451,6 +435,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
return err
}
hp := NewTCompactProtocol(headerBuf)
+ hp.SetTConfiguration(t.cfg)
// At this point the header is already read into headerBuf,
// and t.frameBuffer starts from the actual payload.
@@ -458,7 +443,17 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
if err != nil {
return err
}
- t.protocolID = THeaderProtocolID(protoID)
+ idPtr, err := THeaderProtocolIDPtr(THeaderProtocolID(protoID))
+ if err != nil {
+ return err
+ }
+ if t.cfg == nil {
+ t.cfg = &TConfiguration{
+ noPropagation: true,
+ }
+ }
+ t.cfg.THeaderProtocolID = idPtr
+
var transformCount int32
transformCount, err = hp.readVarint32()
if err != nil {
@@ -601,7 +596,8 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
case clientHeaders:
headers := NewTMemoryBuffer()
hp := NewTCompactProtocol(headers)
- if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil {
+ hp.SetTConfiguration(t.cfg)
+ if _, err := hp.writeVarint32(int32(t.cfg.GetTHeaderProtocolID())); err != nil {
return NewTTransportExceptionFromError(err)
}
if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil {
@@ -746,7 +742,7 @@ func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error {
func (t *THeaderTransport) Protocol() THeaderProtocolID {
switch t.clientType {
default:
- return t.protocolID
+ return t.cfg.GetTHeaderProtocolID()
case clientFramedBinary, clientUnframedBinary:
return THeaderProtocolBinary
case clientFramedCompact, clientUnframedCompact:
@@ -763,17 +759,37 @@ func (t *THeaderTransport) isFramed() bool {
}
}
+// SetTConfiguration implements TConfigurationSetter.
+func (t *THeaderTransport) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(t.transport, cfg)
+ t.cfg = cfg
+}
+
// THeaderTransportFactory is a TTransportFactory implementation to create
// THeaderTransport.
+//
+// It also implements TConfigurationSetter.
type THeaderTransportFactory struct {
// The underlying factory, could be nil.
Factory TTransportFactory
+
+ cfg *TConfiguration
}
-// NewTHeaderTransportFactory creates a new *THeaderTransportFactory.
+// Deprecated: Use NewTHeaderTransportFactoryConf instead.
func NewTHeaderTransportFactory(factory TTransportFactory) TTransportFactory {
+ return NewTHeaderTransportFactoryConf(factory, &TConfiguration{
+ noPropagation: true,
+ })
+}
+
+// NewTHeaderTransportFactoryConf creates a new *THeaderTransportFactory with
+// the given *TConfiguration.
+func NewTHeaderTransportFactoryConf(factory TTransportFactory, conf *TConfiguration) TTransportFactory {
return &THeaderTransportFactory{
Factory: factory,
+
+ cfg: conf,
}
}
@@ -784,7 +800,18 @@ func (f *THeaderTransportFactory) GetTransport(trans TTransport) (TTransport, er
if err != nil {
return nil, err
}
- return NewTHeaderTransport(t), nil
+ return NewTHeaderTransportConf(t, f.cfg), nil
}
- return NewTHeaderTransport(trans), nil
+ return NewTHeaderTransportConf(trans, f.cfg), nil
}
+
+// SetTConfiguration implements TConfigurationSetter.
+func (f *THeaderTransportFactory) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(f.Factory, f.cfg)
+ f.cfg = cfg
+}
+
+var (
+ _ TConfigurationSetter = (*THeaderTransportFactory)(nil)
+ _ TConfigurationSetter = (*THeaderTransport)(nil)
+)
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 5b47680e8..41efb1898 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -21,6 +21,7 @@ package thrift
import (
"context"
+ "fmt"
"io"
"io/ioutil"
"strings"
@@ -31,10 +32,9 @@ import (
func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocolID) {
trans := NewTMemoryBuffer()
reader := NewTHeaderTransport(trans)
- writer, err := NewTHeaderTransportWithProtocolID(trans, protoID)
- if err != nil {
- t.Fatal(err)
- }
+ writer := NewTHeaderTransportConf(trans, &TConfiguration{
+ THeaderProtocolID: &protoID,
+ })
const key1 = "key1"
const value1 = "value1"
@@ -265,3 +265,19 @@ func TestTHeaderTransportEndOfFrameHandling(t *testing.T) {
t.Error(err)
}
}
+
+func BenchmarkTHeaderProtocolIDValidate(b *testing.B) {
+ for _, c := range []THeaderProtocolID{
+ THeaderProtocolBinary,
+ THeaderProtocolCompact,
+ -1,
+ } {
+ b.Run(fmt.Sprintf("%2v", c), func(b *testing.B) {
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ c.Validate()
+ }
+ })
+ })
+ }
+}
diff --git a/lib/go/thrift/iostream_transport.go b/lib/go/thrift/iostream_transport.go
index 0b1775d06..1c477990f 100644
--- a/lib/go/thrift/iostream_transport.go
+++ b/lib/go/thrift/iostream_transport.go
@@ -212,3 +212,11 @@ func (p *StreamTransport) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (p *StreamTransport) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.Reader, conf)
+ PropagateTConfiguration(p.Writer, conf)
+}
+
+var _ TConfigurationSetter = (*StreamTransport)(nil)
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index edc49cc10..8e59d16cf 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -587,3 +587,5 @@ func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) {
e := fmt.Errorf("Unknown type identifier: %s", fieldType)
return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e)
}
+
+var _ TConfigurationSetter = (*TJSONProtocol)(nil)
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index e94b44bbb..d1a815453 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -1364,3 +1364,10 @@ func (p *TSimpleJSONProtocol) write(b []byte) (int, error) {
}
return n, err
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (p *TSimpleJSONProtocol) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.trans, conf)
+}
+
+var _ TConfigurationSetter = (*TSimpleJSONProtocol)(nil)
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index af75dd1a9..e911bf166 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -26,57 +26,116 @@ import (
)
type TSocket struct {
- conn *socketConn
- addr net.Addr
+ conn *socketConn
+ addr net.Addr
+ cfg *TConfiguration
+
connectTimeout time.Duration
socketTimeout time.Duration
}
-// NewTSocket creates a net.Conn-backed TTransport, given a host and port
-//
-// Example:
-// trans, err := thrift.NewTSocket("localhost:9090")
+// Deprecated: Use NewTSocketConf instead.
func NewTSocket(hostPort string) (*TSocket, error) {
- return NewTSocketTimeout(hostPort, 0, 0)
+ return NewTSocketConf(hostPort, &TConfiguration{
+ noPropagation: true,
+ })
}
-// NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and port
-// it also accepts a timeout as a time.Duration
-func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout time.Duration) (*TSocket, error) {
- //conn, err := net.DialTimeout(network, address, timeout)
+// NewTSocketConf creates a net.Conn-backed TTransport, given a host and port.
+//
+// Example:
+//
+// trans, err := thrift.NewTSocketConf("localhost:9090", &TConfiguration{
+// ConnectTimeout: time.Second, // Use 0 for no timeout
+// SocketTimeout: time.Second, // Use 0 for no timeout
+// })
+func NewTSocketConf(hostPort string, conf *TConfiguration) (*TSocket, error) {
addr, err := net.ResolveTCPAddr("tcp", hostPort)
if err != nil {
return nil, err
}
- return NewTSocketFromAddrTimeout(addr, connTimeout, soTimeout), nil
+ return NewTSocketFromAddrConf(addr, conf), nil
+}
+
+// Deprecated: Use NewTSocketConf instead.
+func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout time.Duration) (*TSocket, error) {
+ return NewTSocketConf(hostPort, &TConfiguration{
+ ConnectTimeout: connTimeout,
+ SocketTimeout: soTimeout,
+
+ noPropagation: true,
+ })
+}
+
+// NewTSocketFromAddrConf creates a TSocket from a net.Addr
+func NewTSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSocket {
+ return &TSocket{
+ addr: addr,
+ cfg: conf,
+ }
}
-// Creates a TSocket from a net.Addr
+// Deprecated: Use NewTSocketFromAddrConf instead.
func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, soTimeout time.Duration) *TSocket {
- return &TSocket{addr: addr, connectTimeout: connTimeout, socketTimeout: soTimeout}
+ return NewTSocketFromAddrConf(addr, &TConfiguration{
+ ConnectTimeout: connTimeout,
+ SocketTimeout: soTimeout,
+
+ noPropagation: true,
+ })
+}
+
+// NewTSocketFromConnConf creates a TSocket from an existing net.Conn.
+func NewTSocketFromConnConf(conn net.Conn, conf *TConfiguration) *TSocket {
+ return &TSocket{
+ conn: wrapSocketConn(conn),
+ addr: conn.RemoteAddr(),
+ cfg: conf,
+ }
}
-// Creates a TSocket from an existing net.Conn
+// Deprecated: Use NewTSocketFromConnConf instead.
func NewTSocketFromConnTimeout(conn net.Conn, socketTimeout time.Duration) *TSocket {
- return &TSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), socketTimeout: socketTimeout}
+ return NewTSocketFromConnConf(conn, &TConfiguration{
+ SocketTimeout: socketTimeout,
+
+ noPropagation: true,
+ })
+}
+
+// SetTConfiguration implements TConfigurationSetter.
+//
+// It can be used to set connect and socket timeouts.
+func (p *TSocket) SetTConfiguration(conf *TConfiguration) {
+ p.cfg = conf
}
// Sets the connect timeout
func (p *TSocket) SetConnTimeout(timeout time.Duration) error {
- p.connectTimeout = timeout
+ if p.cfg == nil {
+ p.cfg = &TConfiguration{
+ noPropagation: true,
+ }
+ }
+ p.cfg.ConnectTimeout = timeout
return nil
}
// Sets the socket timeout
func (p *TSocket) SetSocketTimeout(timeout time.Duration) error {
- p.socketTimeout = timeout
+ if p.cfg == nil {
+ p.cfg = &TConfiguration{
+ noPropagation: true,
+ }
+ }
+ p.cfg.SocketTimeout = timeout
return nil
}
func (p *TSocket) pushDeadline(read, write bool) {
var t time.Time
- if p.socketTimeout > 0 {
- t = time.Now().Add(time.Duration(p.socketTimeout))
+ if timeout := p.cfg.GetSocketTimeout(); timeout > 0 {
+ t = time.Now().Add(time.Duration(timeout))
}
if read && write {
p.conn.SetDeadline(t)
@@ -105,7 +164,7 @@ func (p *TSocket) Open() error {
if p.conn, err = createSocketConnFromReturn(net.DialTimeout(
p.addr.Network(),
p.addr.String(),
- p.connectTimeout,
+ p.cfg.GetConnectTimeout(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
@@ -175,3 +234,5 @@ func (p *TSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}
+
+var _ TConfigurationSetter = (*TSocket)(nil)
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index 15ae96f60..6359a74ce 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -34,70 +34,115 @@ type TSSLSocket struct {
// addr is nil when hostPort is not "", and is only used when the
// TSSLSocket is constructed from a net.Addr.
addr net.Addr
- cfg *tls.Config
- connectTimeout time.Duration
- socketTimeout time.Duration
+ cfg *TConfiguration
}
-// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration
+// NewTSSLSocketConf creates a net.Conn-backed TTransport, given a host and port.
//
// Example:
-// trans, err := thrift.NewTSSLSocket("localhost:9090", nil)
+//
+// trans, err := thrift.NewTSSLSocketConf("localhost:9090", nil, &TConfiguration{
+// ConnectTimeout: time.Second, // Use 0 for no timeout
+// SocketTimeout: time.Second, // Use 0 for no timeout
+// })
+func NewTSSLSocketConf(hostPort string, conf *TConfiguration) (*TSSLSocket, error) {
+ if cfg := conf.GetTLSConfig(); cfg != nil && cfg.MinVersion == 0 {
+ cfg.MinVersion = tls.VersionTLS10
+ }
+ return &TSSLSocket{
+ hostPort: hostPort,
+ cfg: conf,
+ }, nil
+}
+
+// Deprecated: Use NewTSSLSocketConf instead.
func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
- return NewTSSLSocketTimeout(hostPort, cfg, 0, 0)
+ return NewTSSLSocketConf(hostPort, &TConfiguration{
+ TLSConfig: cfg,
+
+ noPropagation: true,
+ })
}
-// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port
-// it also accepts a tls Configuration and connect/socket timeouts as time.Duration
+// Deprecated: Use NewTSSLSocketConf instead.
func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, connectTimeout, socketTimeout time.Duration) (*TSSLSocket, error) {
- if cfg.MinVersion == 0 {
- cfg.MinVersion = tls.VersionTLS10
- }
+ return NewTSSLSocketConf(hostPort, &TConfiguration{
+ ConnectTimeout: connectTimeout,
+ SocketTimeout: socketTimeout,
+ TLSConfig: cfg,
+
+ noPropagation: true,
+ })
+}
+
+// NewTSSLSocketFromAddrConf creates a TSSLSocket from a net.Addr.
+func NewTSSLSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSSLSocket {
return &TSSLSocket{
- hostPort: hostPort,
- cfg: cfg,
- connectTimeout: connectTimeout,
- socketTimeout: socketTimeout,
- }, nil
+ addr: addr,
+ cfg: conf,
+ }
}
-// Creates a TSSLSocket from a net.Addr
+// Deprecated: Use NewTSSLSocketFromAddrConf instead.
func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, connectTimeout, socketTimeout time.Duration) *TSSLSocket {
+ return NewTSSLSocketFromAddrConf(addr, &TConfiguration{
+ ConnectTimeout: connectTimeout,
+ SocketTimeout: socketTimeout,
+ TLSConfig: cfg,
+
+ noPropagation: true,
+ })
+}
+
+// NewTSSLSocketFromConnConf creates a TSSLSocket from an existing net.Conn.
+func NewTSSLSocketFromConnConf(conn net.Conn, conf *TConfiguration) *TSSLSocket {
return &TSSLSocket{
- addr: addr,
- cfg: cfg,
- connectTimeout: connectTimeout,
- socketTimeout: socketTimeout,
+ conn: wrapSocketConn(conn),
+ addr: conn.RemoteAddr(),
+ cfg: conf,
}
}
-// Creates a TSSLSocket from an existing net.Conn
+// Deprecated: Use NewTSSLSocketFromConnConf instead.
func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, socketTimeout time.Duration) *TSSLSocket {
- return &TSSLSocket{
- conn: wrapSocketConn(conn),
- addr: conn.RemoteAddr(),
- cfg: cfg,
- socketTimeout: socketTimeout,
- }
+ return NewTSSLSocketFromConnConf(conn, &TConfiguration{
+ SocketTimeout: socketTimeout,
+ TLSConfig: cfg,
+
+ noPropagation: true,
+ })
+}
+
+// SetTConfiguration implements TConfigurationSetter.
+//
+// It can be used to change connect and socket timeouts.
+func (p *TSSLSocket) SetTConfiguration(conf *TConfiguration) {
+ p.cfg = conf
}
// Sets the connect timeout
func (p *TSSLSocket) SetConnTimeout(timeout time.Duration) error {
- p.connectTimeout = timeout
+ if p.cfg == nil {
+ p.cfg = &TConfiguration{}
+ }
+ p.cfg.ConnectTimeout = timeout
return nil
}
// Sets the socket timeout
func (p *TSSLSocket) SetSocketTimeout(timeout time.Duration) error {
- p.socketTimeout = timeout
+ if p.cfg == nil {
+ p.cfg = &TConfiguration{}
+ }
+ p.cfg.SocketTimeout = timeout
return nil
}
func (p *TSSLSocket) pushDeadline(read, write bool) {
var t time.Time
- if p.socketTimeout > 0 {
- t = time.Now().Add(time.Duration(p.socketTimeout))
+ if timeout := p.cfg.GetSocketTimeout(); timeout > 0 {
+ t = time.Now().Add(time.Duration(timeout))
}
if read && write {
p.conn.SetDeadline(t)
@@ -116,11 +161,11 @@ func (p *TSSLSocket) Open() error {
if p.hostPort != "" {
if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
&net.Dialer{
- Timeout: p.connectTimeout,
+ Timeout: p.cfg.GetConnectTimeout(),
},
"tcp",
p.hostPort,
- p.cfg,
+ p.cfg.GetTLSConfig(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
@@ -139,11 +184,11 @@ func (p *TSSLSocket) Open() error {
}
if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
&net.Dialer{
- Timeout: p.connectTimeout,
+ Timeout: p.cfg.GetConnectTimeout(),
},
p.addr.Network(),
p.addr.String(),
- p.cfg,
+ p.cfg.GetTLSConfig(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
@@ -209,3 +254,5 @@ func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}
+
+var _ TConfigurationSetter = (*TSSLSocket)(nil)
diff --git a/lib/go/thrift/zlib_transport.go b/lib/go/thrift/zlib_transport.go
index e7efdfb9e..259943a62 100644
--- a/lib/go/thrift/zlib_transport.go
+++ b/lib/go/thrift/zlib_transport.go
@@ -128,3 +128,10 @@ func (z *TZlibTransport) RemainingBytes() uint64 {
func (z *TZlibTransport) Write(p []byte) (int, error) {
return z.writer.Write(p)
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (z *TZlibTransport) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(z.transport, conf)
+}
+
+var _ TConfigurationSetter = (*TZlibTransport)(nil)