summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzeshuai007 <51382517@qq.com>2020-06-15 17:00:33 +0800
committerJens Geyer <jensg@apache.org>2020-07-25 12:13:53 +0200
commit86352b4821085d63861deab59c46ef1042fbfe81 (patch)
tree6c9c441d4125e4bb115e9989a769c99b36212677
parent23c8e52fa0708c53f74958944ecf04b293d1db73 (diff)
downloadthrift-86352b4821085d63861deab59c46ef1042fbfe81.tar.gz
THRIFT-5237 Implement MAX_MESSAGE_SIZE and consolidate limits into a TConfiguration class
Client: cpp Patch: Zezeng Wang This closes #2185
-rwxr-xr-xlib/cpp/Makefile.am7
-rw-r--r--lib/cpp/src/thrift/TConfiguration.h55
-rw-r--r--lib/cpp/src/thrift/protocol/TBinaryProtocol.h18
-rw-r--r--lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc37
-rw-r--r--lib/cpp/src/thrift/protocol/TCompactProtocol.h18
-rw-r--r--lib/cpp/src/thrift/protocol/TCompactProtocol.tcc32
-rw-r--r--lib/cpp/src/thrift/protocol/TEnum.h66
-rw-r--r--lib/cpp/src/thrift/protocol/TJSONProtocol.cpp35
-rw-r--r--lib/cpp/src/thrift/protocol/TJSONProtocol.h18
-rw-r--r--lib/cpp/src/thrift/protocol/TList.h55
-rw-r--r--lib/cpp/src/thrift/protocol/TMap.h59
-rw-r--r--lib/cpp/src/thrift/protocol/TProtocol.h68
-rw-r--r--lib/cpp/src/thrift/protocol/TSet.h61
-rw-r--r--lib/cpp/src/thrift/transport/TBufferTransports.cpp2
-rw-r--r--lib/cpp/src/thrift/transport/TBufferTransports.h54
-rw-r--r--lib/cpp/src/thrift/transport/TFDTransport.cpp1
-rw-r--r--lib/cpp/src/thrift/transport/TFDTransport.h6
-rw-r--r--lib/cpp/src/thrift/transport/TFileTransport.cpp8
-rw-r--r--lib/cpp/src/thrift/transport/TFileTransport.h2
-rw-r--r--lib/cpp/src/thrift/transport/THeaderTransport.cpp1
-rw-r--r--lib/cpp/src/thrift/transport/THeaderTransport.h10
-rw-r--r--lib/cpp/src/thrift/transport/THttpClient.cpp13
-rw-r--r--lib/cpp/src/thrift/transport/THttpClient.h7
-rw-r--r--lib/cpp/src/thrift/transport/THttpServer.cpp5
-rw-r--r--lib/cpp/src/thrift/transport/THttpServer.h2
-rw-r--r--lib/cpp/src/thrift/transport/THttpTransport.cpp6
-rw-r--r--lib/cpp/src/thrift/transport/THttpTransport.h6
-rw-r--r--lib/cpp/src/thrift/transport/TPipe.cpp24
-rw-r--r--lib/cpp/src/thrift/transport/TPipe.h12
-rw-r--r--lib/cpp/src/thrift/transport/TSSLSocket.cpp29
-rw-r--r--lib/cpp/src/thrift/transport/TSSLSocket.h15
-rw-r--r--lib/cpp/src/thrift/transport/TShortReadTransport.h17
-rw-r--r--lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp4
-rw-r--r--lib/cpp/src/thrift/transport/TSimpleFileTransport.h3
-rw-r--r--lib/cpp/src/thrift/transport/TSocket.cpp27
-rw-r--r--lib/cpp/src/thrift/transport/TSocket.h11
-rw-r--r--lib/cpp/src/thrift/transport/TTransport.h90
-rw-r--r--lib/cpp/src/thrift/transport/TTransportUtils.cpp7
-rw-r--r--lib/cpp/src/thrift/transport/TTransportUtils.h15
-rw-r--r--lib/cpp/src/thrift/transport/TVirtualTransport.h4
-rw-r--r--lib/cpp/src/thrift/transport/TWebSocketServer.h5
-rw-r--r--lib/cpp/src/thrift/transport/TZlibTransport.cpp3
-rw-r--r--lib/cpp/src/thrift/transport/TZlibTransport.h6
-rw-r--r--lib/cpp/test/CMakeLists.txt1
-rwxr-xr-xlib/cpp/test/Makefile.am3
-rw-r--r--lib/cpp/test/ThrifttReadCheckTests.cpp227
46 files changed, 1004 insertions, 151 deletions
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index a536d1719..3a0c4e63b 100755
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -141,7 +141,8 @@ include_thrift_HEADERS = \
src/thrift/TApplicationException.h \
src/thrift/TLogging.h \
src/thrift/TToString.h \
- src/thrift/TBase.h
+ src/thrift/TBase.h \
+ src/thrift/TConfiguration.h
include_concurrencydir = $(include_thriftdir)/concurrency
include_concurrency_HEADERS = \
@@ -156,6 +157,10 @@ include_concurrency_HEADERS = \
include_protocoldir = $(include_thriftdir)/protocol
include_protocol_HEADERS = \
+ src/thrift/protocol/TEnum.h \
+ src/thrift/protocol/TList.h \
+ src/thrift/protocol/TSet.h \
+ src/thrift/protocol/TMap.h \
src/thrift/protocol/TBinaryProtocol.h \
src/thrift/protocol/TBinaryProtocol.tcc \
src/thrift/protocol/TCompactProtocol.h \
diff --git a/lib/cpp/src/thrift/TConfiguration.h b/lib/cpp/src/thrift/TConfiguration.h
new file mode 100644
index 000000000..5bff440a0
--- /dev/null
+++ b/lib/cpp/src/thrift/TConfiguration.h
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef THRIFT_TCONFIGURATION_H
+#define THRIFT_TCONFIGURATION_H
+
+namespace apache {
+namespace thrift {
+
+class TConfiguration
+{
+public:
+ TConfiguration(int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE,
+ int maxFrameSize = DEFAULT_MAX_FRAME_SIZE, int recursionLimit = DEFAULT_RECURSION_DEPTH)
+ : maxMessageSize_(maxMessageSize), maxFrameSize_(maxFrameSize), recursionLimit_(recursionLimit) {}
+
+ const static int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
+ const static int DEFAULT_MAX_FRAME_SIZE = 16384000; // this value is used consistently across all Thrift libraries
+ const static int DEFAULT_RECURSION_DEPTH = 64;
+
+ inline int getMaxMessageSize() { return maxMessageSize_; }
+ inline void setMaxMessageSize(int maxMessageSize) { maxMessageSize_ = maxMessageSize; }
+ inline int getMaxFrameSize() { return maxFrameSize_; }
+ inline void setMaxFrameSize(int maxFrameSize) { maxFrameSize_ = maxFrameSize; }
+ inline int getRecursionLimit() { return recursionLimit_; }
+ inline void setRecursionLimit(int recursionLimit) { recursionLimit_ = recursionLimit; }
+
+private:
+ int maxMessageSize_ = DEFAULT_MAX_MESSAGE_SIZE;
+ int maxFrameSize_ = DEFAULT_MAX_FRAME_SIZE;
+ int recursionLimit_ = DEFAULT_RECURSION_DEPTH;
+
+ // TODO(someone_smart): add connection and i/o timeouts
+};
+}
+} // apache::thrift
+
+#endif /* THRIFT_TCONFIGURATION_H */
+
diff --git a/lib/cpp/src/thrift/protocol/TBinaryProtocol.h b/lib/cpp/src/thrift/protocol/TBinaryProtocol.h
index 6bd5fb830..b43144017 100644
--- a/lib/cpp/src/thrift/protocol/TBinaryProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TBinaryProtocol.h
@@ -166,6 +166,24 @@ public:
inline uint32_t readBinary(std::string& str);
+ int getMinSerializedSize(TType type);
+
+ void checkReadBytesAvailable(TSet& set)
+ {
+ trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
+ }
+
+ void checkReadBytesAvailable(TList& list)
+ {
+ trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
+ }
+
+ void checkReadBytesAvailable(TMap& map)
+ {
+ int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
+ trans_->checkReadBytesAvailable(map.size_ * elmSize);
+ }
+
protected:
template <typename StrType>
uint32_t readStringBody(StrType& str, int32_t sz);
diff --git a/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc b/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc
index 2964f25d0..755f24386 100644
--- a/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc
+++ b/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc
@@ -21,6 +21,7 @@
#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_TCC_ 1
#include <thrift/protocol/TBinaryProtocol.h>
+#include <thrift/transport/TTransportException.h>
#include <limits>
@@ -285,6 +286,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readMapBegin(TType& keyType,
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
+
+ TMap map(keyType, valType, size);
+ checkReadBytesAvailable(map);
+
return result;
}
@@ -307,6 +312,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readListBegin(TType& elemType
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
+
+ TList list(elemType, size);
+ checkReadBytesAvailable(list);
+
return result;
}
@@ -329,6 +338,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readSetBegin(TType& elemType,
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
+
+ TSet set(elemType, size);
+ checkReadBytesAvailable(set);
+
return result;
}
@@ -447,6 +460,30 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readStringBody(StrType& str,
this->trans_->readAll(reinterpret_cast<uint8_t*>(&str[0]), size);
return (uint32_t)size;
}
+
+// Return the minimum number of bytes a type will consume on the wire
+template <class Transport_, class ByteOrder_>
+int TBinaryProtocolT<Transport_, ByteOrder_>::getMinSerializedSize(TType type)
+{
+ switch (type)
+ {
+ case T_STOP: return 0;
+ case T_VOID: return 0;
+ case T_BOOL: return sizeof(int8_t);
+ case T_BYTE: return sizeof(int8_t);
+ case T_DOUBLE: return sizeof(double);
+ case T_I16: return sizeof(short);
+ case T_I32: return sizeof(int);
+ case T_I64: return sizeof(long);
+ case T_STRING: return sizeof(int); // string length
+ case T_STRUCT: return 0; // empty struct
+ case T_MAP: return sizeof(int); // element count
+ case T_SET: return sizeof(int); // element count
+ case T_LIST: return sizeof(int); // element count
+ default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
+ }
+}
+
}
}
} // apache::thrift::protocol
diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.h b/lib/cpp/src/thrift/protocol/TCompactProtocol.h
index 2930aba29..6f990b2d6 100644
--- a/lib/cpp/src/thrift/protocol/TCompactProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.h
@@ -140,6 +140,24 @@ public:
uint32_t writeBinary(const std::string& str);
+ int getMinSerializedSize(TType type);
+
+ void checkReadBytesAvailable(TSet& set)
+ {
+ trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
+ }
+
+ void checkReadBytesAvailable(TList& list)
+ {
+ trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
+ }
+
+ void checkReadBytesAvailable(TMap& map)
+ {
+ int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
+ trans_->checkReadBytesAvailable(map.size_ * elmSize);
+ }
+
/**
* These methods are called by structs, but don't actually have any wired
* output or purpose
diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc b/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
index d1e342efd..16780911c 100644
--- a/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
+++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
@@ -538,6 +538,9 @@ uint32_t TCompactProtocolT<Transport_>::readMapBegin(TType& keyType,
valType = getTType((int8_t)((uint8_t)kvType & 0xf));
size = (uint32_t)msize;
+ TMap map(keyType, valType, size);
+ checkReadBytesAvailable(map);
+
return rsize;
}
@@ -570,6 +573,9 @@ uint32_t TCompactProtocolT<Transport_>::readListBegin(TType& elemType,
elemType = getTType((int8_t)(size_and_type & 0x0f));
size = (uint32_t)lsize;
+ TList list(elemType, size);
+ checkReadBytesAvailable(list);
+
return rsize;
}
@@ -706,6 +712,8 @@ uint32_t TCompactProtocolT<Transport_>::readBinary(std::string& str) {
trans_->readAll(string_buf_, size);
str.assign((char*)string_buf_, size);
+ trans_->checkReadBytesAvailable(rsize + (uint32_t)size);
+
return rsize + (uint32_t)size;
}
@@ -821,6 +829,30 @@ TType TCompactProtocolT<Transport_>::getTType(int8_t type) {
}
}
+// Return the minimum number of bytes a type will consume on the wire
+template <class Transport_>
+int TCompactProtocolT<Transport_>::getMinSerializedSize(TType type)
+{
+ switch (type)
+ {
+ case T_STOP: return 0;
+ case T_VOID: return 0;
+ case T_BOOL: return sizeof(int8_t);
+ case T_DOUBLE: return 8; // uses fixedLongToBytes() which always writes 8 bytes
+ case T_BYTE: return sizeof(int8_t);
+ case T_I16: return sizeof(int8_t); // zigzag
+ case T_I32: return sizeof(int8_t); // zigzag
+ case T_I64: return sizeof(int8_t); // zigzag
+ case T_STRING: return sizeof(int8_t); // string length
+ case T_STRUCT: return 0; // empty struct
+ case T_MAP: return sizeof(int8_t); // element count
+ case T_SET: return sizeof(int8_t); // element count
+ case T_LIST: return sizeof(int8_t); // element count
+ default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
+ }
+}
+
+
}}} // apache::thrift::protocol
#endif // _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_TCC_
diff --git a/lib/cpp/src/thrift/protocol/TEnum.h b/lib/cpp/src/thrift/protocol/TEnum.h
new file mode 100644
index 000000000..9636785e3
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TEnum.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_ENUM_H_
+#define _THRIFT_ENUM_H_
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+/**
+ * Enumerated definition of the types that the Thrift protocol supports.
+ * Take special note of the T_END type which is used specifically to mark
+ * the end of a sequence of fields.
+ */
+enum TType {
+ T_STOP = 0,
+ T_VOID = 1,
+ T_BOOL = 2,
+ T_BYTE = 3,
+ T_I08 = 3,
+ T_I16 = 6,
+ T_I32 = 8,
+ T_U64 = 9,
+ T_I64 = 10,
+ T_DOUBLE = 4,
+ T_STRING = 11,
+ T_UTF7 = 11,
+ T_STRUCT = 12,
+ T_MAP = 13,
+ T_SET = 14,
+ T_LIST = 15,
+ T_UTF8 = 16,
+ T_UTF16 = 17
+};
+
+/**
+ * Enumerated definition of the message types that the Thrift protocol
+ * supports.
+ */
+enum TMessageType {
+ T_CALL = 1,
+ T_REPLY = 2,
+ T_EXCEPTION = 3,
+ T_ONEWAY = 4
+};
+
+}}} // apache::thrift::protocol
+
+#endif // #define _THRIFT_ENUM_H_
diff --git a/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp b/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp
index 28d0da299..6e4e8ef0d 100644
--- a/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp
+++ b/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp
@@ -1013,6 +1013,10 @@ uint32_t TJSONProtocol::readMapBegin(TType& keyType, TType& valType, uint32_t& s
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
result += readJSONObjectStart();
+
+ TMap map(keyType, valType, size);
+ checkReadBytesAvailable(map);
+
return result;
}
@@ -1032,6 +1036,10 @@ uint32_t TJSONProtocol::readListBegin(TType& elemType, uint32_t& size) {
if (tmpVal > (std::numeric_limits<uint32_t>::max)())
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
+
+ TList list(elemType, size);
+ checkReadBytesAvailable(list);
+
return result;
}
@@ -1049,6 +1057,10 @@ uint32_t TJSONProtocol::readSetBegin(TType& elemType, uint32_t& size) {
if (tmpVal > (std::numeric_limits<uint32_t>::max)())
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
+
+ TSet set(elemType, size);
+ checkReadBytesAvailable(set);
+
return result;
}
@@ -1093,6 +1105,29 @@ uint32_t TJSONProtocol::readString(std::string& str) {
uint32_t TJSONProtocol::readBinary(std::string& str) {
return readJSONBase64(str);
}
+
+// Return the minimum number of bytes a type will consume on the wire
+int TJSONProtocol::getMinSerializedSize(TType type)
+{
+ switch (type)
+ {
+ case T_STOP: return 0;
+ case T_VOID: return 0;
+ case T_BOOL: return 1; // written as int
+ case T_BYTE: return 1;
+ case T_DOUBLE: return 1;
+ case T_I16: return 1;
+ case T_I32: return 1;
+ case T_I64: return 1;
+ case T_STRING: return 2; // empty string
+ case T_STRUCT: return 2; // empty struct
+ case T_MAP: return 2; // empty map
+ case T_SET: return 2; // empty set
+ case T_LIST: return 2; // empty list
+ default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
+ }
+}
+
}
}
} // apache::thrift::protocol
diff --git a/lib/cpp/src/thrift/protocol/TJSONProtocol.h b/lib/cpp/src/thrift/protocol/TJSONProtocol.h
index 420995ef3..e775240ab 100644
--- a/lib/cpp/src/thrift/protocol/TJSONProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TJSONProtocol.h
@@ -245,6 +245,24 @@ public:
uint32_t readBinary(std::string& str);
+ int getMinSerializedSize(TType type);
+
+ void checkReadBytesAvailable(TSet& set)
+ {
+ trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
+ }
+
+ void checkReadBytesAvailable(TList& list)
+ {
+ trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
+ }
+
+ void checkReadBytesAvailable(TMap& map)
+ {
+ int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
+ trans_->checkReadBytesAvailable(map.size_ * elmSize);
+ }
+
class LookaheadReader {
public:
diff --git a/lib/cpp/src/thrift/protocol/TList.h b/lib/cpp/src/thrift/protocol/TList.h
new file mode 100644
index 000000000..bf2c1f9de
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TList.h
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TLIST_H_
+#define _THRIFT_TLIST_H_
+
+#include <thrift/protocol/TEnum.h>
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+// using namespace apache::thrift::protocol;
+
+/**
+ * Helper class that encapsulates list metadata.
+ *
+ */
+class TList {
+public:
+ TList() : elemType_(T_STOP),
+ size_(0) {
+
+ }
+
+ TList(TType t = T_STOP, int s = 0)
+ : elemType_(t),
+ size_(s) {
+
+ }
+
+ TType elemType_;
+ int size_;
+};
+}
+}
+} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_TLIST_H_
diff --git a/lib/cpp/src/thrift/protocol/TMap.h b/lib/cpp/src/thrift/protocol/TMap.h
new file mode 100644
index 000000000..b52ea8faf
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TMap.h
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TMAP_H_
+#define _THRIFT_TMAP_H_
+
+#include <thrift/protocol/TEnum.h>
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+using namespace apache::thrift::protocol;
+
+/**
+ * Helper class that encapsulates map metadata.
+ *
+ */
+class TMap {
+public:
+ TMap()
+ : keyType_(T_STOP),
+ valueType_(T_STOP),
+ size_(0) {
+
+ }
+
+ TMap(TType k, TType v, int s)
+ : keyType_(k),
+ valueType_(v),
+ size_(s) {
+
+ }
+
+ TType keyType_;
+ TType valueType_;
+ int size_;
+};
+}
+}
+} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_TMAP_H_
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index df9c5c39b..867ceb079 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -27,6 +27,10 @@
#include <thrift/transport/TTransport.h>
#include <thrift/protocol/TProtocolException.h>
+#include <thrift/protocol/TEnum.h>
+#include <thrift/protocol/TList.h>
+#include <thrift/protocol/TSet.h>
+#include <thrift/protocol/TMap.h>
#include <memory>
@@ -171,45 +175,6 @@ namespace protocol {
using apache::thrift::transport::TTransport;
/**
- * Enumerated definition of the types that the Thrift protocol supports.
- * Take special note of the T_END type which is used specifically to mark
- * the end of a sequence of fields.
- */
-enum TType {
- T_STOP = 0,
- T_VOID = 1,
- T_BOOL = 2,
- T_BYTE = 3,
- T_I08 = 3,
- T_I16 = 6,
- T_I32 = 8,
- T_U64 = 9,
- T_I64 = 10,
- T_DOUBLE = 4,
- T_STRING = 11,
- T_UTF7 = 11,
- T_STRUCT = 12,
- T_MAP = 13,
- T_SET = 14,
- T_LIST = 15,
- T_UTF8 = 16,
- T_UTF16 = 17
-};
-
-/**
- * Enumerated definition of the message types that the Thrift protocol
- * supports.
- */
-enum TMessageType {
- T_CALL = 1,
- T_REPLY = 2,
- T_EXCEPTION = 3,
- T_ONEWAY = 4
-};
-
-static const uint32_t DEFAULT_RECURSION_LIMIT = 64;
-
-/**
* Abstract class for a thrift protocol driver. These are all the methods that
* a protocol must implement. Essentially, there must be some way of reading
* and writing all the base types, plus a mechanism for writing out structs
@@ -578,11 +543,34 @@ public:
uint32_t getRecursionLimit() const {return recursion_limit_;}
void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}
+ // Returns the minimum amount of bytes needed to store the smallest possible instance of TType.
+ virtual int getMinSerializedSize(TType type) {
+ THRIFT_UNUSED_VARIABLE(type);
+ return 0;
+ }
+
protected:
TProtocol(std::shared_ptr<TTransport> ptrans)
- : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
+ : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0),
+ recursion_limit_(ptrans->getConfiguration()->getRecursionLimit())
{}
+ virtual void checkReadBytesAvailable(TSet& set)
+ {
+ ptrans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
+ }
+
+ virtual void checkReadBytesAvailable(TList& list)
+ {
+ ptrans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
+ }
+
+ virtual void checkReadBytesAvailable(TMap& map)
+ {
+ int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
+ ptrans_->checkReadBytesAvailable(map.size_ * elmSize);
+ }
+
std::shared_ptr<TTransport> ptrans_;
private:
diff --git a/lib/cpp/src/thrift/protocol/TSet.h b/lib/cpp/src/thrift/protocol/TSet.h
new file mode 100644
index 000000000..3a4718cdc
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TSet.h
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TSET_H_
+#define _THRIFT_TSET_H_
+
+#include <thrift/protocol/TEnum.h>
+#include <thrift/protocol/TList.h>
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+using namespace apache::thrift::protocol;
+
+/**
+ * Helper class that encapsulates set metadata.
+ *
+ */
+class TSet {
+public:
+ TSet() : elemType_(T_STOP), size_(0) {
+
+ }
+
+ TSet(TType t, int s)
+ : elemType_(t),
+ size_(s) {
+
+ }
+
+ TSet(TList list)
+ : elemType_(list.elemType_),
+ size_(list.size_) {
+
+ }
+
+ TType elemType_;
+ int size_;
+};
+}
+}
+} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_TSET_H_
diff --git a/lib/cpp/src/thrift/transport/TBufferTransports.cpp b/lib/cpp/src/thrift/transport/TBufferTransports.cpp
index d8a1b3e28..45c0c9bf9 100644
--- a/lib/cpp/src/thrift/transport/TBufferTransports.cpp
+++ b/lib/cpp/src/thrift/transport/TBufferTransports.cpp
@@ -118,6 +118,7 @@ const uint8_t* TBufferedTransport::borrowSlow(uint8_t* buf, uint32_t* len) {
}
void TBufferedTransport::flush() {
+ resetConsumedMessageSize();
// Write out any data waiting in the write buffer.
auto have_bytes = static_cast<uint32_t>(wBase_ - wBuf_.get());
if (have_bytes > 0) {
@@ -248,6 +249,7 @@ void TFramedTransport::writeSlow(const uint8_t* buf, uint32_t len) {
}
void TFramedTransport::flush() {
+ resetConsumedMessageSize();
int32_t sz_hbo, sz_nbo;
assert(wBufSize_ > sizeof(sz_nbo));
diff --git a/lib/cpp/src/thrift/transport/TBufferTransports.h b/lib/cpp/src/thrift/transport/TBufferTransports.h
index 86f0c5acc..179934ba0 100644
--- a/lib/cpp/src/thrift/transport/TBufferTransports.h
+++ b/lib/cpp/src/thrift/transport/TBufferTransports.h
@@ -62,6 +62,7 @@ public:
* This method is meant to eventually be nonvirtual and inlinable.
*/
uint32_t read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint8_t* new_rBase = rBase_ + len;
if (TDB_LIKELY(new_rBase <= rBound_)) {
std::memcpy(buf, rBase_, len);
@@ -120,6 +121,7 @@ public:
* Consume doesn't require a slow path.
*/
void consume(uint32_t len) {
+ countConsumedMessageBytes(len);
if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= rBound_ - rBase_)) {
rBase_ += len;
} else {
@@ -148,7 +150,8 @@ protected:
* performance-sensitive operation, so it is okay to just leave it to
* the concrete class to set up pointers correctly.
*/
- TBufferBase() : rBase_(nullptr), rBound_(nullptr), wBase_(nullptr), wBound_(nullptr) {}
+ TBufferBase(std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config), rBase_(nullptr), rBound_(nullptr), wBase_(nullptr), wBound_(nullptr) {}
/// Convenience mutator for setting the read buffer.
void setReadBuffer(uint8_t* buf, uint32_t len) {
@@ -186,8 +189,9 @@ public:
static const int DEFAULT_BUFFER_SIZE = 512;
/// Use default buffer sizes.
- TBufferedTransport(std::shared_ptr<TTransport> transport)
- : transport_(transport),
+ TBufferedTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(DEFAULT_BUFFER_SIZE),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(new uint8_t[rBufSize_]),
@@ -196,8 +200,9 @@ public:
}
/// Use specified buffer sizes.
- TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t sz)
- : transport_(transport),
+ TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t sz, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(sz),
wBufSize_(sz),
rBuf_(new uint8_t[rBufSize_]),
@@ -206,8 +211,10 @@ public:
}
/// Use specified read and write buffer sizes.
- TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz)
- : transport_(transport),
+ TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(rsz),
wBufSize_(wsz),
rBuf_(new uint8_t[rBufSize_]),
@@ -309,8 +316,9 @@ public:
static const int DEFAULT_MAX_FRAME_SIZE = 256 * 1024 * 1024;
/// Use default buffer sizes.
- TFramedTransport()
- : transport_(),
+ TFramedTransport(std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(),
rBufSize_(0),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(),
@@ -319,27 +327,30 @@ public:
initPointers();
}
- TFramedTransport(std::shared_ptr<TTransport> transport)
- : transport_(transport),
+ TFramedTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(0),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(),
wBuf_(new uint8_t[wBufSize_]),
bufReclaimThresh_((std::numeric_limits<uint32_t>::max)()),
- maxFrameSize_(DEFAULT_MAX_FRAME_SIZE) {
+ maxFrameSize_(configuration_->getMaxFrameSize()) {
initPointers();
}
TFramedTransport(std::shared_ptr<TTransport> transport,
uint32_t sz,
- uint32_t bufReclaimThresh = (std::numeric_limits<uint32_t>::max)())
- : transport_(transport),
+ uint32_t bufReclaimThresh = (std::numeric_limits<uint32_t>::max)(),
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(0),
wBufSize_(sz),
rBuf_(),
wBuf_(new uint8_t[wBufSize_]),
bufReclaimThresh_(bufReclaimThresh),
- maxFrameSize_(DEFAULT_MAX_FRAME_SIZE) {
+ maxFrameSize_(configuration_->getMaxFrameSize()) {
initPointers();
}
@@ -503,7 +514,10 @@ public:
* Construct a TMemoryBuffer with a default-sized buffer,
* owned by the TMemoryBuffer object.
*/
- TMemoryBuffer() { initCommon(nullptr, defaultSize, true, 0); }
+ TMemoryBuffer(std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config) {
+ initCommon(nullptr, defaultSize, true, 0);
+ }
/**
* Construct a TMemoryBuffer with a buffer of a specified size,
@@ -511,7 +525,10 @@ public:
*
* @param sz The initial size of the buffer.
*/
- TMemoryBuffer(uint32_t sz) { initCommon(nullptr, sz, true, 0); }
+ TMemoryBuffer(uint32_t sz, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config) {
+ initCommon(nullptr, sz, true, 0);
+ }
/**
* Construct a TMemoryBuffer with buf as its initial contents.
@@ -523,7 +540,8 @@ public:
* @param sz The size of @c buf.
* @param policy See @link MemoryPolicy @endlink .
*/
- TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) {
+ TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config) {
if (buf == nullptr && sz != 0) {
throw TTransportException(TTransportException::BAD_ARGS,
"TMemoryBuffer given null buffer with non-zero size.");
diff --git a/lib/cpp/src/thrift/transport/TFDTransport.cpp b/lib/cpp/src/thrift/transport/TFDTransport.cpp
index 93dd10021..fa7f0dabe 100644
--- a/lib/cpp/src/thrift/transport/TFDTransport.cpp
+++ b/lib/cpp/src/thrift/transport/TFDTransport.cpp
@@ -52,6 +52,7 @@ void TFDTransport::close() {
}
uint32_t TFDTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
unsigned int maxRetries = 5; // same as the TSocket default
unsigned int retries = 0;
while (true) {
diff --git a/lib/cpp/src/thrift/transport/TFDTransport.h b/lib/cpp/src/thrift/transport/TFDTransport.h
index a3cf51948..fb84c9d8d 100644
--- a/lib/cpp/src/thrift/transport/TFDTransport.h
+++ b/lib/cpp/src/thrift/transport/TFDTransport.h
@@ -40,8 +40,10 @@ class TFDTransport : public TVirtualTransport<TFDTransport> {
public:
enum ClosePolicy { NO_CLOSE_ON_DESTROY = 0, CLOSE_ON_DESTROY = 1 };
- TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY)
- : fd_(fd), close_policy_(close_policy) {}
+ TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config), fd_(fd), close_policy_(close_policy) {
+ }
~TFDTransport() override {
if (close_policy_ == CLOSE_ON_DESTROY) {
diff --git a/lib/cpp/src/thrift/transport/TFileTransport.cpp b/lib/cpp/src/thrift/transport/TFileTransport.cpp
index eaf2bc365..08372b3e2 100644
--- a/lib/cpp/src/thrift/transport/TFileTransport.cpp
+++ b/lib/cpp/src/thrift/transport/TFileTransport.cpp
@@ -63,8 +63,9 @@ using std::string;
using namespace apache::thrift::protocol;
using namespace apache::thrift::concurrency;
-TFileTransport::TFileTransport(string path, bool readOnly)
- : readState_(),
+TFileTransport::TFileTransport(string path, bool readOnly, std::shared_ptr<TConfiguration> config)
+ : TTransport(config),
+ readState_(),
readBuff_(nullptr),
currentEvent_(nullptr),
readBuffSize_(DEFAULT_READ_BUFF_SIZE),
@@ -519,6 +520,7 @@ void TFileTransport::writerThread() {
}
void TFileTransport::flush() {
+ resetConsumedMessageSize();
// file must be open for writing for any flushing to take place
if (!writerThread_.get()) {
return;
@@ -537,6 +539,7 @@ void TFileTransport::flush() {
}
uint32_t TFileTransport::readAll(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint32_t have = 0;
uint32_t get = 0;
@@ -568,6 +571,7 @@ bool TFileTransport::peek() {
}
uint32_t TFileTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
// check if there an event is ready to be read
if (!currentEvent_) {
currentEvent_ = readEvent();
diff --git a/lib/cpp/src/thrift/transport/TFileTransport.h b/lib/cpp/src/thrift/transport/TFileTransport.h
index 0df5cf909..608cff184 100644
--- a/lib/cpp/src/thrift/transport/TFileTransport.h
+++ b/lib/cpp/src/thrift/transport/TFileTransport.h
@@ -173,7 +173,7 @@ public:
*/
class TFileTransport : public TFileReaderTransport, public TFileWriterTransport {
public:
- TFileTransport(std::string path, bool readOnly = false);
+ TFileTransport(std::string path, bool readOnly = false, std::shared_ptr<TConfiguration> config = nullptr);
~TFileTransport() override;
// TODO: what is the correct behaviour for this?
diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.cpp b/lib/cpp/src/thrift/transport/THeaderTransport.cpp
index b582d8da7..b3b833389 100644
--- a/lib/cpp/src/thrift/transport/THeaderTransport.cpp
+++ b/lib/cpp/src/thrift/transport/THeaderTransport.cpp
@@ -415,6 +415,7 @@ void THeaderTransport::clearHeaders() {
}
void THeaderTransport::flush() {
+ resetConsumedMessageSize();
// Write out any data waiting in the write buffer.
uint32_t haveBytes = getWriteBytes();
diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.h b/lib/cpp/src/thrift/transport/THeaderTransport.h
index d1e9d4339..63a4ac880 100644
--- a/lib/cpp/src/thrift/transport/THeaderTransport.h
+++ b/lib/cpp/src/thrift/transport/THeaderTransport.h
@@ -74,8 +74,9 @@ public:
static const int THRIFT_MAX_VARINT32_BYTES = 5;
/// Use default buffer sizes.
- explicit THeaderTransport(const std::shared_ptr<TTransport>& transport)
- : TVirtualTransport(transport),
+ explicit THeaderTransport(const std::shared_ptr<TTransport>& transport,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(transport, config),
outTransport_(transport),
protoId(T_COMPACT_PROTOCOL),
clientType(THRIFT_HEADER_CLIENT_TYPE),
@@ -88,8 +89,9 @@ public:
}
THeaderTransport(const std::shared_ptr<TTransport> inTransport,
- const std::shared_ptr<TTransport> outTransport)
- : TVirtualTransport(inTransport),
+ const std::shared_ptr<TTransport> outTransport,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(inTransport, config),
outTransport_(outTransport),
protoId(T_COMPACT_PROTOCOL),
clientType(THRIFT_HEADER_CLIENT_TYPE),
diff --git a/lib/cpp/src/thrift/transport/THttpClient.cpp b/lib/cpp/src/thrift/transport/THttpClient.cpp
index fdee787c6..ea2eb99af 100644
--- a/lib/cpp/src/thrift/transport/THttpClient.cpp
+++ b/lib/cpp/src/thrift/transport/THttpClient.cpp
@@ -34,12 +34,16 @@ namespace transport {
THttpClient::THttpClient(std::shared_ptr<TTransport> transport,
std::string host,
- std::string path)
- : THttpTransport(transport), host_(host), path_(path) {
+ std::string path,
+ std::shared_ptr<TConfiguration> config)
+ : THttpTransport(transport, config),
+ host_(host),
+ path_(path) {
}
-THttpClient::THttpClient(string host, int port, string path)
- : THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port))),
+THttpClient::THttpClient(string host, int port, string path,
+ std::shared_ptr<TConfiguration> config)
+ : THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port)), config),
host_(host),
path_(path) {
}
@@ -93,6 +97,7 @@ bool THttpClient::parseStatusLine(char* status) {
}
void THttpClient::flush() {
+ resetConsumedMessageSize();
// Fetch the contents of the write buffer
uint8_t* buf;
uint32_t len;
diff --git a/lib/cpp/src/thrift/transport/THttpClient.h b/lib/cpp/src/thrift/transport/THttpClient.h
index 81ddc56c5..f0d7e8b27 100644
--- a/lib/cpp/src/thrift/transport/THttpClient.h
+++ b/lib/cpp/src/thrift/transport/THttpClient.h
@@ -40,13 +40,16 @@ public:
*/
THttpClient(std::shared_ptr<TTransport> transport,
std::string host = "localhost",
- std::string path = "/service");
+ std::string path = "/service",
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* @brief Constructor that will create a new socket transport using the host
* and port.
*/
- THttpClient(std::string host, int port, std::string path = "");
+ THttpClient(std::string host, int port,
+ std::string path = "",
+ std::shared_ptr<TConfiguration> config = nullptr);
~THttpClient() override;
diff --git a/lib/cpp/src/thrift/transport/THttpServer.cpp b/lib/cpp/src/thrift/transport/THttpServer.cpp
index 98518fd55..91a1c39af 100644
--- a/lib/cpp/src/thrift/transport/THttpServer.cpp
+++ b/lib/cpp/src/thrift/transport/THttpServer.cpp
@@ -34,7 +34,9 @@ namespace apache {
namespace thrift {
namespace transport {
-THttpServer::THttpServer(std::shared_ptr<TTransport> transport) : THttpTransport(transport) {
+THttpServer::THttpServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config)
+ : THttpTransport(transport, config) {
+
}
THttpServer::~THttpServer() = default;
@@ -118,6 +120,7 @@ bool THttpServer::parseStatusLine(char* status) {
}
void THttpServer::flush() {
+ resetConsumedMessageSize();
// Fetch the contents of the write buffer
uint8_t* buf;
uint32_t len;
diff --git a/lib/cpp/src/thrift/transport/THttpServer.h b/lib/cpp/src/thrift/transport/THttpServer.h
index d2196911c..bc98986d7 100644
--- a/lib/cpp/src/thrift/transport/THttpServer.h
+++ b/lib/cpp/src/thrift/transport/THttpServer.h
@@ -28,7 +28,7 @@ namespace transport {
class THttpServer : public THttpTransport {
public:
- THttpServer(std::shared_ptr<TTransport> transport);
+ THttpServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr);
~THttpServer() override;
diff --git a/lib/cpp/src/thrift/transport/THttpTransport.cpp b/lib/cpp/src/thrift/transport/THttpTransport.cpp
index aea2b2847..305221e9d 100644
--- a/lib/cpp/src/thrift/transport/THttpTransport.cpp
+++ b/lib/cpp/src/thrift/transport/THttpTransport.cpp
@@ -31,8 +31,9 @@ namespace transport {
const char* THttpTransport::CRLF = "\r\n";
const int THttpTransport::CRLF_LEN = 2;
-THttpTransport::THttpTransport(std::shared_ptr<TTransport> transport)
- : transport_(transport),
+THttpTransport::THttpTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ transport_(transport),
origin_(""),
readHeaders_(true),
chunked_(false),
@@ -61,6 +62,7 @@ THttpTransport::~THttpTransport() {
}
uint32_t THttpTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
if (readBuffer_.available_read() == 0) {
readBuffer_.resetBuffer();
uint32_t got = readMoreData();
diff --git a/lib/cpp/src/thrift/transport/THttpTransport.h b/lib/cpp/src/thrift/transport/THttpTransport.h
index 75f0d8c07..5d2bd37fe 100644
--- a/lib/cpp/src/thrift/transport/THttpTransport.h
+++ b/lib/cpp/src/thrift/transport/THttpTransport.h
@@ -36,7 +36,7 @@ namespace transport {
*/
class THttpTransport : public TVirtualTransport<THttpTransport> {
public:
- THttpTransport(std::shared_ptr<TTransport> transport);
+ THttpTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr);
~THttpTransport() override;
@@ -54,7 +54,9 @@ public:
void write(const uint8_t* buf, uint32_t len);
- void flush() override = 0;
+ void flush() override {
+ resetConsumedMessageSize();
+ };
const std::string getOrigin() const override;
diff --git a/lib/cpp/src/thrift/transport/TPipe.cpp b/lib/cpp/src/thrift/transport/TPipe.cpp
index 4c2fea927..953cec167 100644
--- a/lib/cpp/src/thrift/transport/TPipe.cpp
+++ b/lib/cpp/src/thrift/transport/TPipe.cpp
@@ -222,30 +222,35 @@ uint32_t pseudo_sync_read(HANDLE pipe, HANDLE event, uint8_t* buf, uint32_t len)
}
//---- Constructors ----
-TPipe::TPipe(TAutoHandle &Pipe)
- : impl_(new TWaitableNamedPipeImpl(Pipe)), TimeoutSeconds_(3), isAnonymous_(false) {
+TPipe::TPipe(TAutoHandle &Pipe, std::shared_ptr<TConfiguration> config)
+ : impl_(new TWaitableNamedPipeImpl(Pipe)), TimeoutSeconds_(3),
+ isAnonymous_(false), TVirtualTransport(config) {
}
-TPipe::TPipe(HANDLE Pipe)
- : TimeoutSeconds_(3), isAnonymous_(false)
+TPipe::TPipe(HANDLE Pipe, std::shared_ptr<TConfiguration> config)
+ : TimeoutSeconds_(3), isAnonymous_(false), TVirtualTransport(config)
{
TAutoHandle pipeHandle(Pipe);
impl_.reset(new TWaitableNamedPipeImpl(pipeHandle));
}
-TPipe::TPipe(const char* pipename) : TimeoutSeconds_(3), isAnonymous_(false) {
+TPipe::TPipe(const char* pipename, std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3),
+ isAnonymous_(false), TVirtualTransport(config) {
setPipename(pipename);
}
-TPipe::TPipe(const std::string& pipename) : TimeoutSeconds_(3), isAnonymous_(false) {
+TPipe::TPipe(const std::string& pipename, std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3),
+ isAnonymous_(false), TVirtualTransport(config) {
setPipename(pipename);
}
-TPipe::TPipe(HANDLE PipeRd, HANDLE PipeWrt)
- : impl_(new TAnonPipeImpl(PipeRd, PipeWrt)), TimeoutSeconds_(3), isAnonymous_(true) {
+TPipe::TPipe(HANDLE PipeRd, HANDLE PipeWrt, std::shared_ptr<TConfiguration> config)
+ : impl_(new TAnonPipeImpl(PipeRd, PipeWrt)), TimeoutSeconds_(3), isAnonymous_(true),
+ TVirtualTransport(config) {
}
-TPipe::TPipe() : TimeoutSeconds_(3), isAnonymous_(false) {
+TPipe::TPipe(std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3), isAnonymous_(false),
+ TVirtualTransport(config) {
}
TPipe::~TPipe() {
@@ -299,6 +304,7 @@ void TPipe::close() {
}
uint32_t TPipe::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
if (!isOpen())
throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open pipe");
return impl_->read(buf, len);
diff --git a/lib/cpp/src/thrift/transport/TPipe.h b/lib/cpp/src/thrift/transport/TPipe.h
index ba149b109..7795151a6 100644
--- a/lib/cpp/src/thrift/transport/TPipe.h
+++ b/lib/cpp/src/thrift/transport/TPipe.h
@@ -49,15 +49,15 @@ class TPipeImpl;
class TPipe : public TVirtualTransport<TPipe> {
public:
// Constructs a new pipe object.
- TPipe();
+ TPipe(std::shared_ptr<TConfiguration> config = nullptr);
// Named pipe constructors -
- explicit TPipe(HANDLE Pipe); // HANDLE is a void*
- explicit TPipe(TAutoHandle& Pipe); // this ctor will clear out / move from Pipe
+ explicit TPipe(HANDLE Pipe, std::shared_ptr<TConfiguration> config = nullptr); // HANDLE is a void*
+ explicit TPipe(TAutoHandle& Pipe, std::shared_ptr<TConfiguration> config = nullptr); // this ctor will clear out / move from Pipe
// need a const char * overload so string literals don't go to the HANDLE overload
- explicit TPipe(const char* pipename);
- explicit TPipe(const std::string& pipename);
+ explicit TPipe(const char* pipename, std::shared_ptr<TConfiguration> config = nullptr);
+ explicit TPipe(const std::string& pipename, std::shared_ptr<TConfiguration> config = nullptr);
// Anonymous pipe -
- TPipe(HANDLE PipeRd, HANDLE PipeWrt);
+ TPipe(HANDLE PipeRd, HANDLE PipeWrt, std::shared_ptr<TConfiguration> config = nullptr);
// Destroys the pipe object, closing it if necessary.
virtual ~TPipe();
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
index aa76980aa..9efc5fc50 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
@@ -214,34 +214,37 @@ SSL* SSLContext::createSSL() {
}
// TSSLSocket implementation
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx)
- : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config)
+ : TSocket(config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener)
- : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config)
+ : TSocket(config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
interruptListener_ = interruptListener;
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket)
- : TSocket(socket), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config)
+ : TSocket(socket, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener)
- : TSocket(socket, interruptListener), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config)
+ : TSocket(socket, interruptListener, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port)
- : TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<TConfiguration> config)
+ : TSocket(host, port, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener)
- : TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config)
+ : TSocket(host, port, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
interruptListener_ = interruptListener;
}
@@ -391,6 +394,7 @@ void TSSLSocket::close() {
* exception incase of failure.
*/
uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
initializeHandshake();
if (!checkHandshake())
throw TTransportException(TTransportException::UNKNOWN, "retry again");
@@ -553,6 +557,7 @@ uint32_t TSSLSocket::write_partial(const uint8_t* buf, uint32_t len) {
}
void TSSLSocket::flush() {
+ resetConsumedMessageSize();
// Don't throw exception if not open. Thrift servers close socket twice.
if (ssl_ == nullptr) {
return;
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h
index a78112c89..5afc571f9 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.h
@@ -111,37 +111,40 @@ protected:
/**
* Constructor.
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor with an interrupt signal.
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor, create an instance of TSSLSocket given an existing socket.
*
* @param socket An existing socket
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted.
*
* @param socket An existing socket
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor.
*
* @param host Remote host name
* @param port Remote port number
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor with an interrupt signal.
*
* @param host Remote host name
* @param port Remote port number
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* Authorize peer access after SSL handshake completes.
*/
diff --git a/lib/cpp/src/thrift/transport/TShortReadTransport.h b/lib/cpp/src/thrift/transport/TShortReadTransport.h
index 185c78dc7..c99e6a72b 100644
--- a/lib/cpp/src/thrift/transport/TShortReadTransport.h
+++ b/lib/cpp/src/thrift/transport/TShortReadTransport.h
@@ -38,8 +38,10 @@ namespace test {
*/
class TShortReadTransport : public TVirtualTransport<TShortReadTransport> {
public:
- TShortReadTransport(std::shared_ptr<TTransport> transport, double full_prob)
- : transport_(transport), fullProb_(full_prob) {}
+ TShortReadTransport(std::shared_ptr<TTransport> transport, double full_prob,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config), transport_(transport), fullProb_(full_prob) {
+ }
bool isOpen() const override { return transport_->isOpen(); }
@@ -50,6 +52,7 @@ public:
void close() override { transport_->close(); }
uint32_t read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
if (len == 0) {
return 0;
}
@@ -62,11 +65,17 @@ public:
void write(const uint8_t* buf, uint32_t len) { transport_->write(buf, len); }
- void flush() override { transport_->flush(); }
+ void flush() override {
+ resetConsumedMessageSize();
+ transport_->flush();
+ }
const uint8_t* borrow(uint8_t* buf, uint32_t* len) { return transport_->borrow(buf, len); }
- void consume(uint32_t len) { return transport_->consume(len); }
+ void consume(uint32_t len) {
+ countConsumedMessageBytes(len);
+ return transport_->consume(len);
+ }
std::shared_ptr<TTransport> getUnderlyingTransport() { return transport_; }
diff --git a/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp b/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp
index 4b1399e14..c41affb79 100644
--- a/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp
+++ b/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp
@@ -35,8 +35,8 @@ namespace apache {
namespace thrift {
namespace transport {
-TSimpleFileTransport::TSimpleFileTransport(const std::string& path, bool read, bool write)
- : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY) {
+TSimpleFileTransport::TSimpleFileTransport(const std::string& path, bool read, bool write, std::shared_ptr<TConfiguration> config)
+ : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY, config) {
int flags = 0;
if (read && write) {
flags = O_RDWR;
diff --git a/lib/cpp/src/thrift/transport/TSimpleFileTransport.h b/lib/cpp/src/thrift/transport/TSimpleFileTransport.h
index 32e18974d..24741b0f3 100644
--- a/lib/cpp/src/thrift/transport/TSimpleFileTransport.h
+++ b/lib/cpp/src/thrift/transport/TSimpleFileTransport.h
@@ -33,7 +33,8 @@ namespace transport {
*/
class TSimpleFileTransport : public TFDTransport {
public:
- TSimpleFileTransport(const std::string& path, bool read = true, bool write = false);
+ TSimpleFileTransport(const std::string& path, bool read = true, bool write = false,
+ std::shared_ptr<TConfiguration> config = nullptr);
};
}
}
diff --git a/lib/cpp/src/thrift/transport/TSocket.cpp b/lib/cpp/src/thrift/transport/TSocket.cpp
index a1a6dfb2f..81aaccf43 100644
--- a/lib/cpp/src/thrift/transport/TSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSocket.cpp
@@ -77,8 +77,9 @@ namespace transport {
*
*/
-TSocket::TSocket(const string& host, int port)
- : host_(host),
+TSocket::TSocket(const string& host, int port, std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ host_(host),
port_(port),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
@@ -92,8 +93,9 @@ TSocket::TSocket(const string& host, int port)
maxRecvRetries_(5) {
}
-TSocket::TSocket(const string& path)
- : port_(0),
+TSocket::TSocket(const string& path, std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ port_(0),
path_(path),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
@@ -108,8 +110,9 @@ TSocket::TSocket(const string& path)
cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
}
-TSocket::TSocket()
- : port_(0),
+TSocket::TSocket(std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ port_(0),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
connTimeout_(0),
@@ -123,8 +126,9 @@ TSocket::TSocket()
cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
}
-TSocket::TSocket(THRIFT_SOCKET socket)
- : port_(0),
+TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ port_(0),
socket_(socket),
peerPort_(0),
connTimeout_(0),
@@ -144,8 +148,10 @@ TSocket::TSocket(THRIFT_SOCKET socket)
#endif
}
-TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener)
- : port_(0),
+TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ port_(0),
socket_(socket),
peerPort_(0),
interruptListener_(interruptListener),
@@ -522,6 +528,7 @@ void TSocket::setSocketFD(THRIFT_SOCKET socket) {
}
uint32_t TSocket::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
if (socket_ == THRIFT_INVALID_SOCKET) {
throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
}
diff --git a/lib/cpp/src/thrift/transport/TSocket.h b/lib/cpp/src/thrift/transport/TSocket.h
index b0e8ade83..043f0de8d 100644
--- a/lib/cpp/src/thrift/transport/TSocket.h
+++ b/lib/cpp/src/thrift/transport/TSocket.h
@@ -52,7 +52,7 @@ public:
* socket.
*
*/
- TSocket();
+ TSocket(std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructs a new socket. Note that this does NOT actually connect the
@@ -61,7 +61,7 @@ public:
* @param host An IP address or hostname to connect to
* @param port The port to connect on
*/
- TSocket(const std::string& host, int port);
+ TSocket(const std::string& host, int port, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructs a new Unix domain socket.
@@ -69,7 +69,7 @@ public:
*
* @param path The Unix domain socket e.g. "/tmp/ThriftTest.binary.thrift"
*/
- TSocket(const std::string& path);
+ TSocket(const std::string& path, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Destroyes the socket object, closing it if necessary.
@@ -264,13 +264,14 @@ public:
/**
* Constructor to create socket from file descriptor.
*/
- TSocket(THRIFT_SOCKET socket);
+ TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor to create socket from file descriptor that
* can be interrupted safely.
*/
- TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
+ TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* Set a cache of the peer address (used when trivially available: e.g.
diff --git a/lib/cpp/src/thrift/transport/TTransport.h b/lib/cpp/src/thrift/transport/TTransport.h
index 63978829f..5f657f89b 100644
--- a/lib/cpp/src/thrift/transport/TTransport.h
+++ b/lib/cpp/src/thrift/transport/TTransport.h
@@ -21,6 +21,7 @@
#define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1
#include <thrift/Thrift.h>
+#include <thrift/TConfiguration.h>
#include <thrift/transport/TTransportException.h>
#include <memory>
#include <string>
@@ -55,6 +56,15 @@ uint32_t readAll(Transport_& trans, uint8_t* buf, uint32_t len) {
*/
class TTransport {
public:
+ TTransport(std::shared_ptr<TConfiguration> config = nullptr) {
+ if(config == nullptr) {
+ configuration_ = std::shared_ptr<TConfiguration> (new TConfiguration());
+ } else {
+ configuration_ = config;
+ }
+ resetConsumedMessageSize();
+ }
+
/**
* Virtual deconstructor.
*/
@@ -238,11 +248,87 @@ public:
*/
virtual const std::string getOrigin() const { return "Unknown"; }
+ std::shared_ptr<TConfiguration> getConfiguration() { return configuration_; }
+
+ void setConfiguration(std::shared_ptr<TConfiguration> config) {
+ if (config != nullptr) configuration_ = config;
+ }
+
+ /**
+ * Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport).
+ * Will throw if we already consumed too many bytes or if the new size is larger than allowed.
+ *
+ * @param size real message size
+ */
+ void updateKnownMessageSize(long int size)
+ {
+ long int consumed = knownMessageSize_ - remainingMessageSize_;
+ resetConsumedMessageSize(size);
+ countConsumedMessageBytes(consumed);
+ }
+
+ /**
+ * Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data
+ *
+ * @param numBytes numBytes bytes of data
+ */
+ void checkReadBytesAvailable(long int numBytes)
+ {
+ if (remainingMessageSize_ < numBytes)
+ throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
+ }
+
protected:
+ std::shared_ptr<TConfiguration> configuration_;
+ long int remainingMessageSize_;
+ long int knownMessageSize_;
+
+ inline long int getRemainingMessageSize() { return remainingMessageSize_; }
+ inline void setRemainingMessageSize(long int remainingMessageSize) { remainingMessageSize_ = remainingMessageSize; }
+ inline int getMaxMessageSize() { return configuration_->getMaxMessageSize(); }
+ inline long int getKnownMessageSize() { return knownMessageSize_; }
+ void setKnownMessageSize(long int knownMessageSize) { knownMessageSize_ = knownMessageSize; }
+
+ /**
+ * Resets RemainingMessageSize to the configured maximum
+ *
+ * @param newSize configured size
+ */
+ void resetConsumedMessageSize(long newSize = -1)
+ {
+ // full reset
+ if (newSize < 0)
+ {
+ knownMessageSize_ = getMaxMessageSize();
+ remainingMessageSize_ = getMaxMessageSize();
+ return;
+ }
+
+ // update only: message size can shrink, but not grow
+ if (newSize > knownMessageSize_)
+ throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
+
+ knownMessageSize_ = newSize;
+ remainingMessageSize_ = newSize;
+ }
+
/**
- * Simple constructor.
+ * Consumes numBytes from the RemainingMessageSize.
+ *
+ * @param numBytes Consumes numBytes
*/
- TTransport() = default;
+ void countConsumedMessageBytes(long int numBytes)
+ {
+ if (remainingMessageSize_ >= numBytes)
+ {
+ remainingMessageSize_ -= numBytes;
+ }
+ else
+ {
+ remainingMessageSize_ = 0;
+ throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
+ }
+ }
};
/**
diff --git a/lib/cpp/src/thrift/transport/TTransportUtils.cpp b/lib/cpp/src/thrift/transport/TTransportUtils.cpp
index 69372f3e2..427a2e7c1 100644
--- a/lib/cpp/src/thrift/transport/TTransportUtils.cpp
+++ b/lib/cpp/src/thrift/transport/TTransportUtils.cpp
@@ -26,6 +26,7 @@ namespace thrift {
namespace transport {
uint32_t TPipedTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint32_t need = len;
// We don't have enough data yet
@@ -104,8 +105,9 @@ void TPipedTransport::flush() {
TPipedFileReaderTransport::TPipedFileReaderTransport(
std::shared_ptr<TFileReaderTransport> srcTrans,
- std::shared_ptr<TTransport> dstTrans)
- : TPipedTransport(srcTrans, dstTrans), srcTrans_(srcTrans) {
+ std::shared_ptr<TTransport> dstTrans,
+ std::shared_ptr<TConfiguration> config)
+ : TPipedTransport(srcTrans, dstTrans, config), srcTrans_(srcTrans) {
}
TPipedFileReaderTransport::~TPipedFileReaderTransport() = default;
@@ -131,6 +133,7 @@ uint32_t TPipedFileReaderTransport::read(uint8_t* buf, uint32_t len) {
}
uint32_t TPipedFileReaderTransport::readAll(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint32_t have = 0;
uint32_t get = 0;
diff --git a/lib/cpp/src/thrift/transport/TTransportUtils.h b/lib/cpp/src/thrift/transport/TTransportUtils.h
index 28c93d2a1..68c25f4c9 100644
--- a/lib/cpp/src/thrift/transport/TTransportUtils.h
+++ b/lib/cpp/src/thrift/transport/TTransportUtils.h
@@ -63,8 +63,10 @@ public:
*/
class TPipedTransport : virtual public TTransport {
public:
- TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans)
- : srcTrans_(srcTrans),
+ TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TTransport(config),
+ srcTrans_(srcTrans),
dstTrans_(dstTrans),
rBufSize_(512),
rPos_(0),
@@ -88,8 +90,10 @@ public:
TPipedTransport(std::shared_ptr<TTransport> srcTrans,
std::shared_ptr<TTransport> dstTrans,
- uint32_t sz)
- : srcTrans_(srcTrans),
+ uint32_t sz,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TTransport(config),
+ srcTrans_(srcTrans),
dstTrans_(dstTrans),
rBufSize_(512),
rPos_(0),
@@ -241,7 +245,8 @@ protected:
class TPipedFileReaderTransport : public TPipedTransport, public TFileReaderTransport {
public:
TPipedFileReaderTransport(std::shared_ptr<TFileReaderTransport> srcTrans,
- std::shared_ptr<TTransport> dstTrans);
+ std::shared_ptr<TTransport> dstTrans,
+ std::shared_ptr<TConfiguration> config = nullptr);
~TPipedFileReaderTransport() override;
diff --git a/lib/cpp/src/thrift/transport/TVirtualTransport.h b/lib/cpp/src/thrift/transport/TVirtualTransport.h
index 0a0485742..44bfa1315 100644
--- a/lib/cpp/src/thrift/transport/TVirtualTransport.h
+++ b/lib/cpp/src/thrift/transport/TVirtualTransport.h
@@ -57,7 +57,7 @@ public:
void consume(uint32_t len) { this->TTransport::consume_virt(len); }
protected:
- TTransportDefaults() = default;
+ TTransportDefaults(std::shared_ptr<TConfiguration> config = nullptr) : TTransport(config) {}
};
/**
@@ -118,7 +118,7 @@ public:
}
protected:
- TVirtualTransport() = default;
+ TVirtualTransport() : Super_() {}
/*
* Templatized constructors, to allow arguments to be passed to the Super_
diff --git a/lib/cpp/src/thrift/transport/TWebSocketServer.h b/lib/cpp/src/thrift/transport/TWebSocketServer.h
index 2e94c839d..7f39f36b9 100644
--- a/lib/cpp/src/thrift/transport/TWebSocketServer.h
+++ b/lib/cpp/src/thrift/transport/TWebSocketServer.h
@@ -53,8 +53,8 @@ std::string base64Encode(unsigned char* data, int length);
template <bool binary>
class TWebSocketServer : public THttpServer {
public:
- TWebSocketServer(std::shared_ptr<TTransport> transport)
- : THttpServer(transport) {
+ TWebSocketServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
+ : THttpServer(transport, config) {
resetHandshake();
}
@@ -98,6 +98,7 @@ public:
}
void flush() override {
+ resetConsumedMessageSize();
writeFrameHeader();
uint8_t* buffer;
uint32_t length;
diff --git a/lib/cpp/src/thrift/transport/TZlibTransport.cpp b/lib/cpp/src/thrift/transport/TZlibTransport.cpp
index b4c43d647..657ce5205 100644
--- a/lib/cpp/src/thrift/transport/TZlibTransport.cpp
+++ b/lib/cpp/src/thrift/transport/TZlibTransport.cpp
@@ -136,6 +136,7 @@ inline int TZlibTransport::readAvail() const {
}
uint32_t TZlibTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint32_t need = len;
// TODO(dreiss): Skip urbuf on big reads.
@@ -265,6 +266,7 @@ void TZlibTransport::flush() {
}
flushToTransport(Z_FULL_FLUSH);
+ resetConsumedMessageSize();
}
void TZlibTransport::finish() {
@@ -335,6 +337,7 @@ const uint8_t* TZlibTransport::borrow(uint8_t* buf, uint32_t* len) {
}
void TZlibTransport::consume(uint32_t len) {
+ countConsumedMessageBytes(len);
if (readAvail() >= (int)len) {
urpos_ += len;
} else {
diff --git a/lib/cpp/src/thrift/transport/TZlibTransport.h b/lib/cpp/src/thrift/transport/TZlibTransport.h
index 4990afff5..85765e6be 100644
--- a/lib/cpp/src/thrift/transport/TZlibTransport.h
+++ b/lib/cpp/src/thrift/transport/TZlibTransport.h
@@ -83,8 +83,10 @@ public:
int crbuf_size = DEFAULT_CRBUF_SIZE,
int uwbuf_size = DEFAULT_UWBUF_SIZE,
int cwbuf_size = DEFAULT_CWBUF_SIZE,
- int16_t comp_level = Z_DEFAULT_COMPRESSION)
- : transport_(transport),
+ int16_t comp_level = Z_DEFAULT_COMPRESSION,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
urpos_(0),
uwpos_(0),
input_ended_(false),
diff --git a/lib/cpp/test/CMakeLists.txt b/lib/cpp/test/CMakeLists.txt
index 48e2fd375..ced78a257 100644
--- a/lib/cpp/test/CMakeLists.txt
+++ b/lib/cpp/test/CMakeLists.txt
@@ -81,6 +81,7 @@ set(UnitTest_SOURCES
TypedefTest.cpp
TServerSocketTest.cpp
TServerTransportTest.cpp
+ ThrifttReadCheckTests.cpp
)
add_executable(UnitTests ${UnitTest_SOURCES})
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index 89826839d..7f630db10 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -130,7 +130,8 @@ UnitTests_SOURCES = \
TypedefTest.cpp \
TServerSocketTest.cpp \
TServerTransportTest.cpp \
- TTransportCheckThrow.h
+ TTransportCheckThrow.h \
+ ThrifttReadCheckTests.cpp
UnitTests_LDADD = \
libtestgencpp.la \
diff --git a/lib/cpp/test/ThrifttReadCheckTests.cpp b/lib/cpp/test/ThrifttReadCheckTests.cpp
new file mode 100644
index 000000000..4a594e6ca
--- /dev/null
+++ b/lib/cpp/test/ThrifttReadCheckTests.cpp
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#define MAX_MESSAGE_SIZE 2
+
+#include <boost/test/auto_unit_test.hpp>
+#include <boost/test/unit_test.hpp>
+#include <iostream>
+#include <climits>
+#include <vector>
+#include <thrift/TConfiguration.h>
+#include <thrift/protocol/TBinaryProtocol.h>
+#include <thrift/protocol/TCompactProtocol.h>
+#include <thrift/protocol/TJSONProtocol.h>
+#include <thrift/Thrift.h>
+#include <memory>
+#include <thrift/transport/TTransportUtils.h>
+#include <thrift/transport/TBufferTransports.h>
+#include <thrift/transport/TSimpleFileTransport.h>
+#include <thrift/transport/TFileTransport.h>
+#include <thrift/protocol/TEnum.h>
+#include <thrift/protocol/TList.h>
+#include <thrift/protocol/TSet.h>
+#include <thrift/protocol/TMap.h>
+
+BOOST_AUTO_TEST_SUITE(ThriftReadCheckExceptionTest)
+
+using apache::thrift::TConfiguration;
+using apache::thrift::protocol::TBinaryProtocol;
+using apache::thrift::protocol::TCompactProtocol;
+using apache::thrift::protocol::TJSONProtocol;
+using apache::thrift::protocol::TType;
+using apache::thrift::transport::TPipedTransport;
+using apache::thrift::transport::TMemoryBuffer;
+using apache::thrift::transport::TSimpleFileTransport;
+using apache::thrift::transport::TFileTransport;
+using apache::thrift::transport::TFDTransport;
+using apache::thrift::transport::TTransportException;
+using apache::thrift::transport::TBufferedTransport;
+using apache::thrift::transport::TFramedTransport;
+using std::shared_ptr;
+using std::cout;
+using std::endl;
+using std::string;
+using std::memset;
+using namespace apache::thrift;
+using namespace apache::thrift::protocol;
+
+
+BOOST_AUTO_TEST_CASE(test_tmemorybuffer_read_check_exception) {
+ std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
+ TMemoryBuffer trans_out(config);
+ uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
+ trans_out.write((const uint8_t*)buffer, sizeof(buffer));
+ trans_out.close();
+
+ TMemoryBuffer trans_in(config);
+ memset(buffer, 0, sizeof(buffer));
+ BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
+ trans_in.close();
+}
+
+BOOST_AUTO_TEST_CASE(test_tpipedtransport_read_check_exception) {
+ std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TMemoryBuffer> pipe(new TMemoryBuffer);
+ std::shared_ptr<TMemoryBuffer> underlying(new TMemoryBuffer);
+ std::shared_ptr<TPipedTransport> trans(new TPipedTransport(underlying, pipe, config));
+
+ uint8_t buffer[4];
+
+ underlying->write((uint8_t*)"abcd", 4);
+ BOOST_CHECK_THROW(trans->read(buffer, sizeof(buffer)), TTransportException*);
+ BOOST_CHECK_THROW(trans->readAll(buffer, sizeof(buffer)), TTransportException*);
+ trans->readEnd();
+ pipe->resetBuffer();
+ underlying->write((uint8_t*)"ef", 2);
+ BOOST_CHECK_THROW(trans->read(buffer, sizeof(buffer)), TTransportException*);
+ BOOST_CHECK_THROW(trans->readAll(buffer, sizeof(buffer)), TTransportException*);
+ trans->readEnd();
+}
+
+BOOST_AUTO_TEST_CASE(test_tsimplefiletransport_read_check_exception) {
+ std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
+ TSimpleFileTransport trans_out("data", false, true, config);
+ uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
+ trans_out.write((const uint8_t*)buffer, sizeof(buffer));
+ trans_out.close();
+
+ TSimpleFileTransport trans_in("data",true, false, config);
+ memset(buffer, 0, sizeof(buffer));
+ BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
+ trans_in.close();
+
+ remove("./data");
+}
+
+BOOST_AUTO_TEST_CASE(test_tfiletransport_read_check_exception) {
+ std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
+ TFileTransport trans_out("data", false, config);
+ uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
+ trans_out.write((const uint8_t*)buffer, sizeof(buffer));
+
+ TFileTransport trans_in("data", false, config);
+ memset(buffer, 0, sizeof(buffer));
+ BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
+
+ remove("./data");
+}
+
+BOOST_AUTO_TEST_CASE(test_tbufferedtransport_read_check_exception) {
+ uint8_t arr[4] = {1, 2, 3, 4};
+ std::shared_ptr<TMemoryBuffer> buffer (new TMemoryBuffer(arr, sizeof(arr)));
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TBufferedTransport> trans (new TBufferedTransport(buffer, config));
+
+ trans->write((const uint8_t*)arr, sizeof(arr));
+ BOOST_CHECK_THROW(trans->read(arr, sizeof(arr)), TTransportException*);
+}
+
+BOOST_AUTO_TEST_CASE(test_tframedtransport_read_check_exception) {
+ uint8_t arr[4] = {1, 2, 3, 4};
+ std::shared_ptr<TMemoryBuffer> buffer (new TMemoryBuffer(arr, sizeof(arr)));
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TFramedTransport> trans (new TFramedTransport(buffer, config));
+
+ trans->write((const uint8_t*)arr, sizeof(arr));
+ BOOST_CHECK_THROW(trans->read(arr, sizeof(arr)), TTransportException*);
+}
+
+BOOST_AUTO_TEST_CASE(test_tthriftbinaryprotocol_read_check_exception) {
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
+ std::shared_ptr<TBinaryProtocol> protocol(new TBinaryProtocol(transport));
+
+ uint32_t val = 0;
+ TType elemType = apache::thrift::protocol::T_STOP;
+ TType elemType1 = apache::thrift::protocol::T_STOP;
+ TList list(T_I32, 8);
+ protocol->writeListBegin(list.elemType_, list.size_);
+ protocol->writeListEnd();
+ BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
+ protocol->readListEnd();
+
+ TSet set(T_I32, 8);
+ protocol->writeSetBegin(set.elemType_, set.size_);
+ protocol->writeSetEnd();
+ BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
+ protocol->readSetEnd();
+
+ TMap map(T_I32, T_I32, 8);
+ protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
+ protocol->writeMapEnd();
+ BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
+ protocol->readMapEnd();
+}
+
+BOOST_AUTO_TEST_CASE(test_tthriftcompactprotocol_read_check_exception) {
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
+ std::shared_ptr<TCompactProtocol> protocol(new TCompactProtocol(transport));
+
+ uint32_t val = 0;
+ TType elemType = apache::thrift::protocol::T_STOP;
+ TType elemType1 = apache::thrift::protocol::T_STOP;
+ TList list(T_I32, 8);
+ protocol->writeListBegin(list.elemType_, list.size_);
+ protocol->writeListEnd();
+ BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
+ protocol->readListEnd();
+
+ TSet set(T_I32, 8);
+ protocol->writeSetBegin(set.elemType_, set.size_);
+ protocol->writeSetEnd();
+ BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
+ protocol->readSetEnd();
+
+ TMap map(T_I32, T_I32, 8);
+ protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
+ protocol->writeMapEnd();
+ BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
+ protocol->readMapEnd();
+}
+
+BOOST_AUTO_TEST_CASE(test_tthriftjsonprotocol_read_check_exception) {
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
+ std::shared_ptr<TJSONProtocol> protocol(new TJSONProtocol(transport));
+
+ uint32_t val = 0;
+ TType elemType = apache::thrift::protocol::T_STOP;
+ TType elemType1 = apache::thrift::protocol::T_STOP;
+ TList list(T_I32, 8);
+ protocol->writeListBegin(list.elemType_, list.size_);
+ protocol->writeListEnd();
+ BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
+ protocol->readListEnd();
+
+ TSet set(T_I32, 8);
+ protocol->writeSetBegin(set.elemType_, set.size_);
+ protocol->writeSetEnd();
+ BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
+ protocol->readSetEnd();
+
+ TMap map(T_I32, T_I32, 8);
+ protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
+ protocol->writeMapEnd();
+ BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
+ protocol->readMapEnd();
+}
+
+BOOST_AUTO_TEST_SUITE_END()