diff options
author | zeshuai007 <51382517@qq.com> | 2020-06-15 17:00:33 +0800 |
---|---|---|
committer | Jens Geyer <jensg@apache.org> | 2020-07-25 12:13:53 +0200 |
commit | 86352b4821085d63861deab59c46ef1042fbfe81 (patch) | |
tree | 6c9c441d4125e4bb115e9989a769c99b36212677 | |
parent | 23c8e52fa0708c53f74958944ecf04b293d1db73 (diff) | |
download | thrift-86352b4821085d63861deab59c46ef1042fbfe81.tar.gz |
THRIFT-5237 Implement MAX_MESSAGE_SIZE and consolidate limits into a TConfiguration class
Client: cpp
Patch: Zezeng Wang
This closes #2185
46 files changed, 1004 insertions, 151 deletions
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am index a536d1719..3a0c4e63b 100755 --- a/lib/cpp/Makefile.am +++ b/lib/cpp/Makefile.am @@ -141,7 +141,8 @@ include_thrift_HEADERS = \ src/thrift/TApplicationException.h \ src/thrift/TLogging.h \ src/thrift/TToString.h \ - src/thrift/TBase.h + src/thrift/TBase.h \ + src/thrift/TConfiguration.h include_concurrencydir = $(include_thriftdir)/concurrency include_concurrency_HEADERS = \ @@ -156,6 +157,10 @@ include_concurrency_HEADERS = \ include_protocoldir = $(include_thriftdir)/protocol include_protocol_HEADERS = \ + src/thrift/protocol/TEnum.h \ + src/thrift/protocol/TList.h \ + src/thrift/protocol/TSet.h \ + src/thrift/protocol/TMap.h \ src/thrift/protocol/TBinaryProtocol.h \ src/thrift/protocol/TBinaryProtocol.tcc \ src/thrift/protocol/TCompactProtocol.h \ diff --git a/lib/cpp/src/thrift/TConfiguration.h b/lib/cpp/src/thrift/TConfiguration.h new file mode 100644 index 000000000..5bff440a0 --- /dev/null +++ b/lib/cpp/src/thrift/TConfiguration.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef THRIFT_TCONFIGURATION_H +#define THRIFT_TCONFIGURATION_H + +namespace apache { +namespace thrift { + +class TConfiguration +{ +public: + TConfiguration(int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE, + int maxFrameSize = DEFAULT_MAX_FRAME_SIZE, int recursionLimit = DEFAULT_RECURSION_DEPTH) + : maxMessageSize_(maxMessageSize), maxFrameSize_(maxFrameSize), recursionLimit_(recursionLimit) {} + + const static int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024; + const static int DEFAULT_MAX_FRAME_SIZE = 16384000; // this value is used consistently across all Thrift libraries + const static int DEFAULT_RECURSION_DEPTH = 64; + + inline int getMaxMessageSize() { return maxMessageSize_; } + inline void setMaxMessageSize(int maxMessageSize) { maxMessageSize_ = maxMessageSize; } + inline int getMaxFrameSize() { return maxFrameSize_; } + inline void setMaxFrameSize(int maxFrameSize) { maxFrameSize_ = maxFrameSize; } + inline int getRecursionLimit() { return recursionLimit_; } + inline void setRecursionLimit(int recursionLimit) { recursionLimit_ = recursionLimit; } + +private: + int maxMessageSize_ = DEFAULT_MAX_MESSAGE_SIZE; + int maxFrameSize_ = DEFAULT_MAX_FRAME_SIZE; + int recursionLimit_ = DEFAULT_RECURSION_DEPTH; + + // TODO(someone_smart): add connection and i/o timeouts +}; +} +} // apache::thrift + +#endif /* THRIFT_TCONFIGURATION_H */ + diff --git a/lib/cpp/src/thrift/protocol/TBinaryProtocol.h b/lib/cpp/src/thrift/protocol/TBinaryProtocol.h index 6bd5fb830..b43144017 100644 --- a/lib/cpp/src/thrift/protocol/TBinaryProtocol.h +++ b/lib/cpp/src/thrift/protocol/TBinaryProtocol.h @@ -166,6 +166,24 @@ public: inline uint32_t readBinary(std::string& str); + int getMinSerializedSize(TType type); + + void checkReadBytesAvailable(TSet& set) + { + trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_)); + } + + void checkReadBytesAvailable(TList& list) + { + trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_)); + } + + void checkReadBytesAvailable(TMap& map) + { + int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_); + trans_->checkReadBytesAvailable(map.size_ * elmSize); + } + protected: template <typename StrType> uint32_t readStringBody(StrType& str, int32_t sz); diff --git a/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc b/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc index 2964f25d0..755f24386 100644 --- a/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc +++ b/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc @@ -21,6 +21,7 @@ #define _THRIFT_PROTOCOL_TBINARYPROTOCOL_TCC_ 1 #include <thrift/protocol/TBinaryProtocol.h> +#include <thrift/transport/TTransportException.h> #include <limits> @@ -285,6 +286,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readMapBegin(TType& keyType, throw TProtocolException(TProtocolException::SIZE_LIMIT); } size = (uint32_t)sizei; + + TMap map(keyType, valType, size); + checkReadBytesAvailable(map); + return result; } @@ -307,6 +312,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readListBegin(TType& elemType throw TProtocolException(TProtocolException::SIZE_LIMIT); } size = (uint32_t)sizei; + + TList list(elemType, size); + checkReadBytesAvailable(list); + return result; } @@ -329,6 +338,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readSetBegin(TType& elemType, throw TProtocolException(TProtocolException::SIZE_LIMIT); } size = (uint32_t)sizei; + + TSet set(elemType, size); + checkReadBytesAvailable(set); + return result; } @@ -447,6 +460,30 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readStringBody(StrType& str, this->trans_->readAll(reinterpret_cast<uint8_t*>(&str[0]), size); return (uint32_t)size; } + +// Return the minimum number of bytes a type will consume on the wire +template <class Transport_, class ByteOrder_> +int TBinaryProtocolT<Transport_, ByteOrder_>::getMinSerializedSize(TType type) +{ + switch (type) + { + case T_STOP: return 0; + case T_VOID: return 0; + case T_BOOL: return sizeof(int8_t); + case T_BYTE: return sizeof(int8_t); + case T_DOUBLE: return sizeof(double); + case T_I16: return sizeof(short); + case T_I32: return sizeof(int); + case T_I64: return sizeof(long); + case T_STRING: return sizeof(int); // string length + case T_STRUCT: return 0; // empty struct + case T_MAP: return sizeof(int); // element count + case T_SET: return sizeof(int); // element count + case T_LIST: return sizeof(int); // element count + default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code"); + } +} + } } } // apache::thrift::protocol diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.h b/lib/cpp/src/thrift/protocol/TCompactProtocol.h index 2930aba29..6f990b2d6 100644 --- a/lib/cpp/src/thrift/protocol/TCompactProtocol.h +++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.h @@ -140,6 +140,24 @@ public: uint32_t writeBinary(const std::string& str); + int getMinSerializedSize(TType type); + + void checkReadBytesAvailable(TSet& set) + { + trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_)); + } + + void checkReadBytesAvailable(TList& list) + { + trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_)); + } + + void checkReadBytesAvailable(TMap& map) + { + int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_); + trans_->checkReadBytesAvailable(map.size_ * elmSize); + } + /** * These methods are called by structs, but don't actually have any wired * output or purpose diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc b/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc index d1e342efd..16780911c 100644 --- a/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc +++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc @@ -538,6 +538,9 @@ uint32_t TCompactProtocolT<Transport_>::readMapBegin(TType& keyType, valType = getTType((int8_t)((uint8_t)kvType & 0xf)); size = (uint32_t)msize; + TMap map(keyType, valType, size); + checkReadBytesAvailable(map); + return rsize; } @@ -570,6 +573,9 @@ uint32_t TCompactProtocolT<Transport_>::readListBegin(TType& elemType, elemType = getTType((int8_t)(size_and_type & 0x0f)); size = (uint32_t)lsize; + TList list(elemType, size); + checkReadBytesAvailable(list); + return rsize; } @@ -706,6 +712,8 @@ uint32_t TCompactProtocolT<Transport_>::readBinary(std::string& str) { trans_->readAll(string_buf_, size); str.assign((char*)string_buf_, size); + trans_->checkReadBytesAvailable(rsize + (uint32_t)size); + return rsize + (uint32_t)size; } @@ -821,6 +829,30 @@ TType TCompactProtocolT<Transport_>::getTType(int8_t type) { } } +// Return the minimum number of bytes a type will consume on the wire +template <class Transport_> +int TCompactProtocolT<Transport_>::getMinSerializedSize(TType type) +{ + switch (type) + { + case T_STOP: return 0; + case T_VOID: return 0; + case T_BOOL: return sizeof(int8_t); + case T_DOUBLE: return 8; // uses fixedLongToBytes() which always writes 8 bytes + case T_BYTE: return sizeof(int8_t); + case T_I16: return sizeof(int8_t); // zigzag + case T_I32: return sizeof(int8_t); // zigzag + case T_I64: return sizeof(int8_t); // zigzag + case T_STRING: return sizeof(int8_t); // string length + case T_STRUCT: return 0; // empty struct + case T_MAP: return sizeof(int8_t); // element count + case T_SET: return sizeof(int8_t); // element count + case T_LIST: return sizeof(int8_t); // element count + default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code"); + } +} + + }}} // apache::thrift::protocol #endif // _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_TCC_ diff --git a/lib/cpp/src/thrift/protocol/TEnum.h b/lib/cpp/src/thrift/protocol/TEnum.h new file mode 100644 index 000000000..9636785e3 --- /dev/null +++ b/lib/cpp/src/thrift/protocol/TEnum.h @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_ENUM_H_ +#define _THRIFT_ENUM_H_ + +namespace apache { +namespace thrift { +namespace protocol { + +/** + * Enumerated definition of the types that the Thrift protocol supports. + * Take special note of the T_END type which is used specifically to mark + * the end of a sequence of fields. + */ +enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +}; + +/** + * Enumerated definition of the message types that the Thrift protocol + * supports. + */ +enum TMessageType { + T_CALL = 1, + T_REPLY = 2, + T_EXCEPTION = 3, + T_ONEWAY = 4 +}; + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_ENUM_H_ diff --git a/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp b/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp index 28d0da299..6e4e8ef0d 100644 --- a/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp +++ b/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp @@ -1013,6 +1013,10 @@ uint32_t TJSONProtocol::readMapBegin(TType& keyType, TType& valType, uint32_t& s throw TProtocolException(TProtocolException::SIZE_LIMIT); size = static_cast<uint32_t>(tmpVal); result += readJSONObjectStart(); + + TMap map(keyType, valType, size); + checkReadBytesAvailable(map); + return result; } @@ -1032,6 +1036,10 @@ uint32_t TJSONProtocol::readListBegin(TType& elemType, uint32_t& size) { if (tmpVal > (std::numeric_limits<uint32_t>::max)()) throw TProtocolException(TProtocolException::SIZE_LIMIT); size = static_cast<uint32_t>(tmpVal); + + TList list(elemType, size); + checkReadBytesAvailable(list); + return result; } @@ -1049,6 +1057,10 @@ uint32_t TJSONProtocol::readSetBegin(TType& elemType, uint32_t& size) { if (tmpVal > (std::numeric_limits<uint32_t>::max)()) throw TProtocolException(TProtocolException::SIZE_LIMIT); size = static_cast<uint32_t>(tmpVal); + + TSet set(elemType, size); + checkReadBytesAvailable(set); + return result; } @@ -1093,6 +1105,29 @@ uint32_t TJSONProtocol::readString(std::string& str) { uint32_t TJSONProtocol::readBinary(std::string& str) { return readJSONBase64(str); } + +// Return the minimum number of bytes a type will consume on the wire +int TJSONProtocol::getMinSerializedSize(TType type) +{ + switch (type) + { + case T_STOP: return 0; + case T_VOID: return 0; + case T_BOOL: return 1; // written as int + case T_BYTE: return 1; + case T_DOUBLE: return 1; + case T_I16: return 1; + case T_I32: return 1; + case T_I64: return 1; + case T_STRING: return 2; // empty string + case T_STRUCT: return 2; // empty struct + case T_MAP: return 2; // empty map + case T_SET: return 2; // empty set + case T_LIST: return 2; // empty list + default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code"); + } +} + } } } // apache::thrift::protocol diff --git a/lib/cpp/src/thrift/protocol/TJSONProtocol.h b/lib/cpp/src/thrift/protocol/TJSONProtocol.h index 420995ef3..e775240ab 100644 --- a/lib/cpp/src/thrift/protocol/TJSONProtocol.h +++ b/lib/cpp/src/thrift/protocol/TJSONProtocol.h @@ -245,6 +245,24 @@ public: uint32_t readBinary(std::string& str); + int getMinSerializedSize(TType type); + + void checkReadBytesAvailable(TSet& set) + { + trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_)); + } + + void checkReadBytesAvailable(TList& list) + { + trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_)); + } + + void checkReadBytesAvailable(TMap& map) + { + int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_); + trans_->checkReadBytesAvailable(map.size_ * elmSize); + } + class LookaheadReader { public: diff --git a/lib/cpp/src/thrift/protocol/TList.h b/lib/cpp/src/thrift/protocol/TList.h new file mode 100644 index 000000000..bf2c1f9de --- /dev/null +++ b/lib/cpp/src/thrift/protocol/TList.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TLIST_H_ +#define _THRIFT_TLIST_H_ + +#include <thrift/protocol/TEnum.h> + +namespace apache { +namespace thrift { +namespace protocol { + +// using namespace apache::thrift::protocol; + +/** + * Helper class that encapsulates list metadata. + * + */ +class TList { +public: + TList() : elemType_(T_STOP), + size_(0) { + + } + + TList(TType t = T_STOP, int s = 0) + : elemType_(t), + size_(s) { + + } + + TType elemType_; + int size_; +}; +} +} +} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_TLIST_H_ diff --git a/lib/cpp/src/thrift/protocol/TMap.h b/lib/cpp/src/thrift/protocol/TMap.h new file mode 100644 index 000000000..b52ea8faf --- /dev/null +++ b/lib/cpp/src/thrift/protocol/TMap.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TMAP_H_ +#define _THRIFT_TMAP_H_ + +#include <thrift/protocol/TEnum.h> + +namespace apache { +namespace thrift { +namespace protocol { + +using namespace apache::thrift::protocol; + +/** + * Helper class that encapsulates map metadata. + * + */ +class TMap { +public: + TMap() + : keyType_(T_STOP), + valueType_(T_STOP), + size_(0) { + + } + + TMap(TType k, TType v, int s) + : keyType_(k), + valueType_(v), + size_(s) { + + } + + TType keyType_; + TType valueType_; + int size_; +}; +} +} +} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_TMAP_H_ diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h index df9c5c39b..867ceb079 100644 --- a/lib/cpp/src/thrift/protocol/TProtocol.h +++ b/lib/cpp/src/thrift/protocol/TProtocol.h @@ -27,6 +27,10 @@ #include <thrift/transport/TTransport.h> #include <thrift/protocol/TProtocolException.h> +#include <thrift/protocol/TEnum.h> +#include <thrift/protocol/TList.h> +#include <thrift/protocol/TSet.h> +#include <thrift/protocol/TMap.h> #include <memory> @@ -171,45 +175,6 @@ namespace protocol { using apache::thrift::transport::TTransport; /** - * Enumerated definition of the types that the Thrift protocol supports. - * Take special note of the T_END type which is used specifically to mark - * the end of a sequence of fields. - */ -enum TType { - T_STOP = 0, - T_VOID = 1, - T_BOOL = 2, - T_BYTE = 3, - T_I08 = 3, - T_I16 = 6, - T_I32 = 8, - T_U64 = 9, - T_I64 = 10, - T_DOUBLE = 4, - T_STRING = 11, - T_UTF7 = 11, - T_STRUCT = 12, - T_MAP = 13, - T_SET = 14, - T_LIST = 15, - T_UTF8 = 16, - T_UTF16 = 17 -}; - -/** - * Enumerated definition of the message types that the Thrift protocol - * supports. - */ -enum TMessageType { - T_CALL = 1, - T_REPLY = 2, - T_EXCEPTION = 3, - T_ONEWAY = 4 -}; - -static const uint32_t DEFAULT_RECURSION_LIMIT = 64; - -/** * Abstract class for a thrift protocol driver. These are all the methods that * a protocol must implement. Essentially, there must be some way of reading * and writing all the base types, plus a mechanism for writing out structs @@ -578,11 +543,34 @@ public: uint32_t getRecursionLimit() const {return recursion_limit_;} void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;} + // Returns the minimum amount of bytes needed to store the smallest possible instance of TType. + virtual int getMinSerializedSize(TType type) { + THRIFT_UNUSED_VARIABLE(type); + return 0; + } + protected: TProtocol(std::shared_ptr<TTransport> ptrans) - : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT) + : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0), + recursion_limit_(ptrans->getConfiguration()->getRecursionLimit()) {} + virtual void checkReadBytesAvailable(TSet& set) + { + ptrans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_)); + } + + virtual void checkReadBytesAvailable(TList& list) + { + ptrans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_)); + } + + virtual void checkReadBytesAvailable(TMap& map) + { + int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_); + ptrans_->checkReadBytesAvailable(map.size_ * elmSize); + } + std::shared_ptr<TTransport> ptrans_; private: diff --git a/lib/cpp/src/thrift/protocol/TSet.h b/lib/cpp/src/thrift/protocol/TSet.h new file mode 100644 index 000000000..3a4718cdc --- /dev/null +++ b/lib/cpp/src/thrift/protocol/TSet.h @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TSET_H_ +#define _THRIFT_TSET_H_ + +#include <thrift/protocol/TEnum.h> +#include <thrift/protocol/TList.h> + +namespace apache { +namespace thrift { +namespace protocol { + +using namespace apache::thrift::protocol; + +/** + * Helper class that encapsulates set metadata. + * + */ +class TSet { +public: + TSet() : elemType_(T_STOP), size_(0) { + + } + + TSet(TType t, int s) + : elemType_(t), + size_(s) { + + } + + TSet(TList list) + : elemType_(list.elemType_), + size_(list.size_) { + + } + + TType elemType_; + int size_; +}; +} +} +} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_TSET_H_ diff --git a/lib/cpp/src/thrift/transport/TBufferTransports.cpp b/lib/cpp/src/thrift/transport/TBufferTransports.cpp index d8a1b3e28..45c0c9bf9 100644 --- a/lib/cpp/src/thrift/transport/TBufferTransports.cpp +++ b/lib/cpp/src/thrift/transport/TBufferTransports.cpp @@ -118,6 +118,7 @@ const uint8_t* TBufferedTransport::borrowSlow(uint8_t* buf, uint32_t* len) { } void TBufferedTransport::flush() { + resetConsumedMessageSize(); // Write out any data waiting in the write buffer. auto have_bytes = static_cast<uint32_t>(wBase_ - wBuf_.get()); if (have_bytes > 0) { @@ -248,6 +249,7 @@ void TFramedTransport::writeSlow(const uint8_t* buf, uint32_t len) { } void TFramedTransport::flush() { + resetConsumedMessageSize(); int32_t sz_hbo, sz_nbo; assert(wBufSize_ > sizeof(sz_nbo)); diff --git a/lib/cpp/src/thrift/transport/TBufferTransports.h b/lib/cpp/src/thrift/transport/TBufferTransports.h index 86f0c5acc..179934ba0 100644 --- a/lib/cpp/src/thrift/transport/TBufferTransports.h +++ b/lib/cpp/src/thrift/transport/TBufferTransports.h @@ -62,6 +62,7 @@ public: * This method is meant to eventually be nonvirtual and inlinable. */ uint32_t read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); uint8_t* new_rBase = rBase_ + len; if (TDB_LIKELY(new_rBase <= rBound_)) { std::memcpy(buf, rBase_, len); @@ -120,6 +121,7 @@ public: * Consume doesn't require a slow path. */ void consume(uint32_t len) { + countConsumedMessageBytes(len); if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= rBound_ - rBase_)) { rBase_ += len; } else { @@ -148,7 +150,8 @@ protected: * performance-sensitive operation, so it is okay to just leave it to * the concrete class to set up pointers correctly. */ - TBufferBase() : rBase_(nullptr), rBound_(nullptr), wBase_(nullptr), wBound_(nullptr) {} + TBufferBase(std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), rBase_(nullptr), rBound_(nullptr), wBase_(nullptr), wBound_(nullptr) {} /// Convenience mutator for setting the read buffer. void setReadBuffer(uint8_t* buf, uint32_t len) { @@ -186,8 +189,9 @@ public: static const int DEFAULT_BUFFER_SIZE = 512; /// Use default buffer sizes. - TBufferedTransport(std::shared_ptr<TTransport> transport) - : transport_(transport), + TBufferedTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), + transport_(transport), rBufSize_(DEFAULT_BUFFER_SIZE), wBufSize_(DEFAULT_BUFFER_SIZE), rBuf_(new uint8_t[rBufSize_]), @@ -196,8 +200,9 @@ public: } /// Use specified buffer sizes. - TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t sz) - : transport_(transport), + TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t sz, std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), + transport_(transport), rBufSize_(sz), wBufSize_(sz), rBuf_(new uint8_t[rBufSize_]), @@ -206,8 +211,10 @@ public: } /// Use specified read and write buffer sizes. - TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz) - : transport_(transport), + TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz, + std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), + transport_(transport), rBufSize_(rsz), wBufSize_(wsz), rBuf_(new uint8_t[rBufSize_]), @@ -309,8 +316,9 @@ public: static const int DEFAULT_MAX_FRAME_SIZE = 256 * 1024 * 1024; /// Use default buffer sizes. - TFramedTransport() - : transport_(), + TFramedTransport(std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), + transport_(), rBufSize_(0), wBufSize_(DEFAULT_BUFFER_SIZE), rBuf_(), @@ -319,27 +327,30 @@ public: initPointers(); } - TFramedTransport(std::shared_ptr<TTransport> transport) - : transport_(transport), + TFramedTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), + transport_(transport), rBufSize_(0), wBufSize_(DEFAULT_BUFFER_SIZE), rBuf_(), wBuf_(new uint8_t[wBufSize_]), bufReclaimThresh_((std::numeric_limits<uint32_t>::max)()), - maxFrameSize_(DEFAULT_MAX_FRAME_SIZE) { + maxFrameSize_(configuration_->getMaxFrameSize()) { initPointers(); } TFramedTransport(std::shared_ptr<TTransport> transport, uint32_t sz, - uint32_t bufReclaimThresh = (std::numeric_limits<uint32_t>::max)()) - : transport_(transport), + uint32_t bufReclaimThresh = (std::numeric_limits<uint32_t>::max)(), + std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), + transport_(transport), rBufSize_(0), wBufSize_(sz), rBuf_(), wBuf_(new uint8_t[wBufSize_]), bufReclaimThresh_(bufReclaimThresh), - maxFrameSize_(DEFAULT_MAX_FRAME_SIZE) { + maxFrameSize_(configuration_->getMaxFrameSize()) { initPointers(); } @@ -503,7 +514,10 @@ public: * Construct a TMemoryBuffer with a default-sized buffer, * owned by the TMemoryBuffer object. */ - TMemoryBuffer() { initCommon(nullptr, defaultSize, true, 0); } + TMemoryBuffer(std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config) { + initCommon(nullptr, defaultSize, true, 0); + } /** * Construct a TMemoryBuffer with a buffer of a specified size, @@ -511,7 +525,10 @@ public: * * @param sz The initial size of the buffer. */ - TMemoryBuffer(uint32_t sz) { initCommon(nullptr, sz, true, 0); } + TMemoryBuffer(uint32_t sz, std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config) { + initCommon(nullptr, sz, true, 0); + } /** * Construct a TMemoryBuffer with buf as its initial contents. @@ -523,7 +540,8 @@ public: * @param sz The size of @c buf. * @param policy See @link MemoryPolicy @endlink . */ - TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) { + TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE, std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config) { if (buf == nullptr && sz != 0) { throw TTransportException(TTransportException::BAD_ARGS, "TMemoryBuffer given null buffer with non-zero size."); diff --git a/lib/cpp/src/thrift/transport/TFDTransport.cpp b/lib/cpp/src/thrift/transport/TFDTransport.cpp index 93dd10021..fa7f0dabe 100644 --- a/lib/cpp/src/thrift/transport/TFDTransport.cpp +++ b/lib/cpp/src/thrift/transport/TFDTransport.cpp @@ -52,6 +52,7 @@ void TFDTransport::close() { } uint32_t TFDTransport::read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); unsigned int maxRetries = 5; // same as the TSocket default unsigned int retries = 0; while (true) { diff --git a/lib/cpp/src/thrift/transport/TFDTransport.h b/lib/cpp/src/thrift/transport/TFDTransport.h index a3cf51948..fb84c9d8d 100644 --- a/lib/cpp/src/thrift/transport/TFDTransport.h +++ b/lib/cpp/src/thrift/transport/TFDTransport.h @@ -40,8 +40,10 @@ class TFDTransport : public TVirtualTransport<TFDTransport> { public: enum ClosePolicy { NO_CLOSE_ON_DESTROY = 0, CLOSE_ON_DESTROY = 1 }; - TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY) - : fd_(fd), close_policy_(close_policy) {} + TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY, + std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), fd_(fd), close_policy_(close_policy) { + } ~TFDTransport() override { if (close_policy_ == CLOSE_ON_DESTROY) { diff --git a/lib/cpp/src/thrift/transport/TFileTransport.cpp b/lib/cpp/src/thrift/transport/TFileTransport.cpp index eaf2bc365..08372b3e2 100644 --- a/lib/cpp/src/thrift/transport/TFileTransport.cpp +++ b/lib/cpp/src/thrift/transport/TFileTransport.cpp @@ -63,8 +63,9 @@ using std::string; using namespace apache::thrift::protocol; using namespace apache::thrift::concurrency; -TFileTransport::TFileTransport(string path, bool readOnly) - : readState_(), +TFileTransport::TFileTransport(string path, bool readOnly, std::shared_ptr<TConfiguration> config) + : TTransport(config), + readState_(), readBuff_(nullptr), currentEvent_(nullptr), readBuffSize_(DEFAULT_READ_BUFF_SIZE), @@ -519,6 +520,7 @@ void TFileTransport::writerThread() { } void TFileTransport::flush() { + resetConsumedMessageSize(); // file must be open for writing for any flushing to take place if (!writerThread_.get()) { return; @@ -537,6 +539,7 @@ void TFileTransport::flush() { } uint32_t TFileTransport::readAll(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); uint32_t have = 0; uint32_t get = 0; @@ -568,6 +571,7 @@ bool TFileTransport::peek() { } uint32_t TFileTransport::read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); // check if there an event is ready to be read if (!currentEvent_) { currentEvent_ = readEvent(); diff --git a/lib/cpp/src/thrift/transport/TFileTransport.h b/lib/cpp/src/thrift/transport/TFileTransport.h index 0df5cf909..608cff184 100644 --- a/lib/cpp/src/thrift/transport/TFileTransport.h +++ b/lib/cpp/src/thrift/transport/TFileTransport.h @@ -173,7 +173,7 @@ public: */ class TFileTransport : public TFileReaderTransport, public TFileWriterTransport { public: - TFileTransport(std::string path, bool readOnly = false); + TFileTransport(std::string path, bool readOnly = false, std::shared_ptr<TConfiguration> config = nullptr); ~TFileTransport() override; // TODO: what is the correct behaviour for this? diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.cpp b/lib/cpp/src/thrift/transport/THeaderTransport.cpp index b582d8da7..b3b833389 100644 --- a/lib/cpp/src/thrift/transport/THeaderTransport.cpp +++ b/lib/cpp/src/thrift/transport/THeaderTransport.cpp @@ -415,6 +415,7 @@ void THeaderTransport::clearHeaders() { } void THeaderTransport::flush() { + resetConsumedMessageSize(); // Write out any data waiting in the write buffer. uint32_t haveBytes = getWriteBytes(); diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.h b/lib/cpp/src/thrift/transport/THeaderTransport.h index d1e9d4339..63a4ac880 100644 --- a/lib/cpp/src/thrift/transport/THeaderTransport.h +++ b/lib/cpp/src/thrift/transport/THeaderTransport.h @@ -74,8 +74,9 @@ public: static const int THRIFT_MAX_VARINT32_BYTES = 5; /// Use default buffer sizes. - explicit THeaderTransport(const std::shared_ptr<TTransport>& transport) - : TVirtualTransport(transport), + explicit THeaderTransport(const std::shared_ptr<TTransport>& transport, + std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(transport, config), outTransport_(transport), protoId(T_COMPACT_PROTOCOL), clientType(THRIFT_HEADER_CLIENT_TYPE), @@ -88,8 +89,9 @@ public: } THeaderTransport(const std::shared_ptr<TTransport> inTransport, - const std::shared_ptr<TTransport> outTransport) - : TVirtualTransport(inTransport), + const std::shared_ptr<TTransport> outTransport, + std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(inTransport, config), outTransport_(outTransport), protoId(T_COMPACT_PROTOCOL), clientType(THRIFT_HEADER_CLIENT_TYPE), diff --git a/lib/cpp/src/thrift/transport/THttpClient.cpp b/lib/cpp/src/thrift/transport/THttpClient.cpp index fdee787c6..ea2eb99af 100644 --- a/lib/cpp/src/thrift/transport/THttpClient.cpp +++ b/lib/cpp/src/thrift/transport/THttpClient.cpp @@ -34,12 +34,16 @@ namespace transport { THttpClient::THttpClient(std::shared_ptr<TTransport> transport, std::string host, - std::string path) - : THttpTransport(transport), host_(host), path_(path) { + std::string path, + std::shared_ptr<TConfiguration> config) + : THttpTransport(transport, config), + host_(host), + path_(path) { } -THttpClient::THttpClient(string host, int port, string path) - : THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port))), +THttpClient::THttpClient(string host, int port, string path, + std::shared_ptr<TConfiguration> config) + : THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port)), config), host_(host), path_(path) { } @@ -93,6 +97,7 @@ bool THttpClient::parseStatusLine(char* status) { } void THttpClient::flush() { + resetConsumedMessageSize(); // Fetch the contents of the write buffer uint8_t* buf; uint32_t len; diff --git a/lib/cpp/src/thrift/transport/THttpClient.h b/lib/cpp/src/thrift/transport/THttpClient.h index 81ddc56c5..f0d7e8b27 100644 --- a/lib/cpp/src/thrift/transport/THttpClient.h +++ b/lib/cpp/src/thrift/transport/THttpClient.h @@ -40,13 +40,16 @@ public: */ THttpClient(std::shared_ptr<TTransport> transport, std::string host = "localhost", - std::string path = "/service"); + std::string path = "/service", + std::shared_ptr<TConfiguration> config = nullptr); /** * @brief Constructor that will create a new socket transport using the host * and port. */ - THttpClient(std::string host, int port, std::string path = ""); + THttpClient(std::string host, int port, + std::string path = "", + std::shared_ptr<TConfiguration> config = nullptr); ~THttpClient() override; diff --git a/lib/cpp/src/thrift/transport/THttpServer.cpp b/lib/cpp/src/thrift/transport/THttpServer.cpp index 98518fd55..91a1c39af 100644 --- a/lib/cpp/src/thrift/transport/THttpServer.cpp +++ b/lib/cpp/src/thrift/transport/THttpServer.cpp @@ -34,7 +34,9 @@ namespace apache { namespace thrift { namespace transport { -THttpServer::THttpServer(std::shared_ptr<TTransport> transport) : THttpTransport(transport) { +THttpServer::THttpServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config) + : THttpTransport(transport, config) { + } THttpServer::~THttpServer() = default; @@ -118,6 +120,7 @@ bool THttpServer::parseStatusLine(char* status) { } void THttpServer::flush() { + resetConsumedMessageSize(); // Fetch the contents of the write buffer uint8_t* buf; uint32_t len; diff --git a/lib/cpp/src/thrift/transport/THttpServer.h b/lib/cpp/src/thrift/transport/THttpServer.h index d2196911c..bc98986d7 100644 --- a/lib/cpp/src/thrift/transport/THttpServer.h +++ b/lib/cpp/src/thrift/transport/THttpServer.h @@ -28,7 +28,7 @@ namespace transport { class THttpServer : public THttpTransport { public: - THttpServer(std::shared_ptr<TTransport> transport); + THttpServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr); ~THttpServer() override; diff --git a/lib/cpp/src/thrift/transport/THttpTransport.cpp b/lib/cpp/src/thrift/transport/THttpTransport.cpp index aea2b2847..305221e9d 100644 --- a/lib/cpp/src/thrift/transport/THttpTransport.cpp +++ b/lib/cpp/src/thrift/transport/THttpTransport.cpp @@ -31,8 +31,9 @@ namespace transport { const char* THttpTransport::CRLF = "\r\n"; const int THttpTransport::CRLF_LEN = 2; -THttpTransport::THttpTransport(std::shared_ptr<TTransport> transport) - : transport_(transport), +THttpTransport::THttpTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config) + : TVirtualTransport(config), + transport_(transport), origin_(""), readHeaders_(true), chunked_(false), @@ -61,6 +62,7 @@ THttpTransport::~THttpTransport() { } uint32_t THttpTransport::read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); if (readBuffer_.available_read() == 0) { readBuffer_.resetBuffer(); uint32_t got = readMoreData(); diff --git a/lib/cpp/src/thrift/transport/THttpTransport.h b/lib/cpp/src/thrift/transport/THttpTransport.h index 75f0d8c07..5d2bd37fe 100644 --- a/lib/cpp/src/thrift/transport/THttpTransport.h +++ b/lib/cpp/src/thrift/transport/THttpTransport.h @@ -36,7 +36,7 @@ namespace transport { */ class THttpTransport : public TVirtualTransport<THttpTransport> { public: - THttpTransport(std::shared_ptr<TTransport> transport); + THttpTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr); ~THttpTransport() override; @@ -54,7 +54,9 @@ public: void write(const uint8_t* buf, uint32_t len); - void flush() override = 0; + void flush() override { + resetConsumedMessageSize(); + }; const std::string getOrigin() const override; diff --git a/lib/cpp/src/thrift/transport/TPipe.cpp b/lib/cpp/src/thrift/transport/TPipe.cpp index 4c2fea927..953cec167 100644 --- a/lib/cpp/src/thrift/transport/TPipe.cpp +++ b/lib/cpp/src/thrift/transport/TPipe.cpp @@ -222,30 +222,35 @@ uint32_t pseudo_sync_read(HANDLE pipe, HANDLE event, uint8_t* buf, uint32_t len) } //---- Constructors ---- -TPipe::TPipe(TAutoHandle &Pipe) - : impl_(new TWaitableNamedPipeImpl(Pipe)), TimeoutSeconds_(3), isAnonymous_(false) { +TPipe::TPipe(TAutoHandle &Pipe, std::shared_ptr<TConfiguration> config) + : impl_(new TWaitableNamedPipeImpl(Pipe)), TimeoutSeconds_(3), + isAnonymous_(false), TVirtualTransport(config) { } -TPipe::TPipe(HANDLE Pipe) - : TimeoutSeconds_(3), isAnonymous_(false) +TPipe::TPipe(HANDLE Pipe, std::shared_ptr<TConfiguration> config) + : TimeoutSeconds_(3), isAnonymous_(false), TVirtualTransport(config) { TAutoHandle pipeHandle(Pipe); impl_.reset(new TWaitableNamedPipeImpl(pipeHandle)); } -TPipe::TPipe(const char* pipename) : TimeoutSeconds_(3), isAnonymous_(false) { +TPipe::TPipe(const char* pipename, std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3), + isAnonymous_(false), TVirtualTransport(config) { setPipename(pipename); } -TPipe::TPipe(const std::string& pipename) : TimeoutSeconds_(3), isAnonymous_(false) { +TPipe::TPipe(const std::string& pipename, std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3), + isAnonymous_(false), TVirtualTransport(config) { setPipename(pipename); } -TPipe::TPipe(HANDLE PipeRd, HANDLE PipeWrt) - : impl_(new TAnonPipeImpl(PipeRd, PipeWrt)), TimeoutSeconds_(3), isAnonymous_(true) { +TPipe::TPipe(HANDLE PipeRd, HANDLE PipeWrt, std::shared_ptr<TConfiguration> config) + : impl_(new TAnonPipeImpl(PipeRd, PipeWrt)), TimeoutSeconds_(3), isAnonymous_(true), + TVirtualTransport(config) { } -TPipe::TPipe() : TimeoutSeconds_(3), isAnonymous_(false) { +TPipe::TPipe(std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3), isAnonymous_(false), + TVirtualTransport(config) { } TPipe::~TPipe() { @@ -299,6 +304,7 @@ void TPipe::close() { } uint32_t TPipe::read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); if (!isOpen()) throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open pipe"); return impl_->read(buf, len); diff --git a/lib/cpp/src/thrift/transport/TPipe.h b/lib/cpp/src/thrift/transport/TPipe.h index ba149b109..7795151a6 100644 --- a/lib/cpp/src/thrift/transport/TPipe.h +++ b/lib/cpp/src/thrift/transport/TPipe.h @@ -49,15 +49,15 @@ class TPipeImpl; class TPipe : public TVirtualTransport<TPipe> { public: // Constructs a new pipe object. - TPipe(); + TPipe(std::shared_ptr<TConfiguration> config = nullptr); // Named pipe constructors - - explicit TPipe(HANDLE Pipe); // HANDLE is a void* - explicit TPipe(TAutoHandle& Pipe); // this ctor will clear out / move from Pipe + explicit TPipe(HANDLE Pipe, std::shared_ptr<TConfiguration> config = nullptr); // HANDLE is a void* + explicit TPipe(TAutoHandle& Pipe, std::shared_ptr<TConfiguration> config = nullptr); // this ctor will clear out / move from Pipe // need a const char * overload so string literals don't go to the HANDLE overload - explicit TPipe(const char* pipename); - explicit TPipe(const std::string& pipename); + explicit TPipe(const char* pipename, std::shared_ptr<TConfiguration> config = nullptr); + explicit TPipe(const std::string& pipename, std::shared_ptr<TConfiguration> config = nullptr); // Anonymous pipe - - TPipe(HANDLE PipeRd, HANDLE PipeWrt); + TPipe(HANDLE PipeRd, HANDLE PipeWrt, std::shared_ptr<TConfiguration> config = nullptr); // Destroys the pipe object, closing it if necessary. virtual ~TPipe(); diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp index aa76980aa..9efc5fc50 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp +++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp @@ -214,34 +214,37 @@ SSL* SSLContext::createSSL() { } // TSSLSocket implementation -TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx) - : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) { +TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config) + : TSocket(config), server_(false), ssl_(nullptr), ctx_(ctx) { init(); } -TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener) - : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) { +TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener, + std::shared_ptr<TConfiguration> config) + : TSocket(config), server_(false), ssl_(nullptr), ctx_(ctx) { init(); interruptListener_ = interruptListener; } -TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket) - : TSocket(socket), server_(false), ssl_(nullptr), ctx_(ctx) { +TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config) + : TSocket(socket, config), server_(false), ssl_(nullptr), ctx_(ctx) { init(); } -TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener) - : TSocket(socket, interruptListener), server_(false), ssl_(nullptr), ctx_(ctx) { +TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener, + std::shared_ptr<TConfiguration> config) + : TSocket(socket, interruptListener, config), server_(false), ssl_(nullptr), ctx_(ctx) { init(); } -TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port) - : TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) { +TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<TConfiguration> config) + : TSocket(host, port, config), server_(false), ssl_(nullptr), ctx_(ctx) { init(); } -TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener) - : TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) { +TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener, + std::shared_ptr<TConfiguration> config) + : TSocket(host, port, config), server_(false), ssl_(nullptr), ctx_(ctx) { init(); interruptListener_ = interruptListener; } @@ -391,6 +394,7 @@ void TSSLSocket::close() { * exception incase of failure. */ uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); initializeHandshake(); if (!checkHandshake()) throw TTransportException(TTransportException::UNKNOWN, "retry again"); @@ -553,6 +557,7 @@ uint32_t TSSLSocket::write_partial(const uint8_t* buf, uint32_t len) { } void TSSLSocket::flush() { + resetConsumedMessageSize(); // Don't throw exception if not open. Thrift servers close socket twice. if (ssl_ == nullptr) { return; diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h index a78112c89..5afc571f9 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.h +++ b/lib/cpp/src/thrift/transport/TSSLSocket.h @@ -111,37 +111,40 @@ protected: /** * Constructor. */ - TSSLSocket(std::shared_ptr<SSLContext> ctx); + TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config = nullptr); /** * Constructor with an interrupt signal. */ - TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener); + TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener, + std::shared_ptr<TConfiguration> config = nullptr); /** * Constructor, create an instance of TSSLSocket given an existing socket. * * @param socket An existing socket */ - TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket); + TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr); /** * Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted. * * @param socket An existing socket */ - TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener); + TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener, + std::shared_ptr<TConfiguration> config = nullptr); /** * Constructor. * * @param host Remote host name * @param port Remote port number */ - TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port); + TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<TConfiguration> config = nullptr); /** * Constructor with an interrupt signal. * * @param host Remote host name * @param port Remote port number */ - TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener); + TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener, + std::shared_ptr<TConfiguration> config = nullptr); /** * Authorize peer access after SSL handshake completes. */ diff --git a/lib/cpp/src/thrift/transport/TShortReadTransport.h b/lib/cpp/src/thrift/transport/TShortReadTransport.h index 185c78dc7..c99e6a72b 100644 --- a/lib/cpp/src/thrift/transport/TShortReadTransport.h +++ b/lib/cpp/src/thrift/transport/TShortReadTransport.h @@ -38,8 +38,10 @@ namespace test { */ class TShortReadTransport : public TVirtualTransport<TShortReadTransport> { public: - TShortReadTransport(std::shared_ptr<TTransport> transport, double full_prob) - : transport_(transport), fullProb_(full_prob) {} + TShortReadTransport(std::shared_ptr<TTransport> transport, double full_prob, + std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), transport_(transport), fullProb_(full_prob) { + } bool isOpen() const override { return transport_->isOpen(); } @@ -50,6 +52,7 @@ public: void close() override { transport_->close(); } uint32_t read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); if (len == 0) { return 0; } @@ -62,11 +65,17 @@ public: void write(const uint8_t* buf, uint32_t len) { transport_->write(buf, len); } - void flush() override { transport_->flush(); } + void flush() override { + resetConsumedMessageSize(); + transport_->flush(); + } const uint8_t* borrow(uint8_t* buf, uint32_t* len) { return transport_->borrow(buf, len); } - void consume(uint32_t len) { return transport_->consume(len); } + void consume(uint32_t len) { + countConsumedMessageBytes(len); + return transport_->consume(len); + } std::shared_ptr<TTransport> getUnderlyingTransport() { return transport_; } diff --git a/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp b/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp index 4b1399e14..c41affb79 100644 --- a/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp +++ b/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp @@ -35,8 +35,8 @@ namespace apache { namespace thrift { namespace transport { -TSimpleFileTransport::TSimpleFileTransport(const std::string& path, bool read, bool write) - : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY) { +TSimpleFileTransport::TSimpleFileTransport(const std::string& path, bool read, bool write, std::shared_ptr<TConfiguration> config) + : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY, config) { int flags = 0; if (read && write) { flags = O_RDWR; diff --git a/lib/cpp/src/thrift/transport/TSimpleFileTransport.h b/lib/cpp/src/thrift/transport/TSimpleFileTransport.h index 32e18974d..24741b0f3 100644 --- a/lib/cpp/src/thrift/transport/TSimpleFileTransport.h +++ b/lib/cpp/src/thrift/transport/TSimpleFileTransport.h @@ -33,7 +33,8 @@ namespace transport { */ class TSimpleFileTransport : public TFDTransport { public: - TSimpleFileTransport(const std::string& path, bool read = true, bool write = false); + TSimpleFileTransport(const std::string& path, bool read = true, bool write = false, + std::shared_ptr<TConfiguration> config = nullptr); }; } } diff --git a/lib/cpp/src/thrift/transport/TSocket.cpp b/lib/cpp/src/thrift/transport/TSocket.cpp index a1a6dfb2f..81aaccf43 100644 --- a/lib/cpp/src/thrift/transport/TSocket.cpp +++ b/lib/cpp/src/thrift/transport/TSocket.cpp @@ -77,8 +77,9 @@ namespace transport { * */ -TSocket::TSocket(const string& host, int port) - : host_(host), +TSocket::TSocket(const string& host, int port, std::shared_ptr<TConfiguration> config) + : TVirtualTransport(config), + host_(host), port_(port), socket_(THRIFT_INVALID_SOCKET), peerPort_(0), @@ -92,8 +93,9 @@ TSocket::TSocket(const string& host, int port) maxRecvRetries_(5) { } -TSocket::TSocket(const string& path) - : port_(0), +TSocket::TSocket(const string& path, std::shared_ptr<TConfiguration> config) + : TVirtualTransport(config), + port_(0), path_(path), socket_(THRIFT_INVALID_SOCKET), peerPort_(0), @@ -108,8 +110,9 @@ TSocket::TSocket(const string& path) cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC; } -TSocket::TSocket() - : port_(0), +TSocket::TSocket(std::shared_ptr<TConfiguration> config) + : TVirtualTransport(config), + port_(0), socket_(THRIFT_INVALID_SOCKET), peerPort_(0), connTimeout_(0), @@ -123,8 +126,9 @@ TSocket::TSocket() cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC; } -TSocket::TSocket(THRIFT_SOCKET socket) - : port_(0), +TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config) + : TVirtualTransport(config), + port_(0), socket_(socket), peerPort_(0), connTimeout_(0), @@ -144,8 +148,10 @@ TSocket::TSocket(THRIFT_SOCKET socket) #endif } -TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener) - : port_(0), +TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener, + std::shared_ptr<TConfiguration> config) + : TVirtualTransport(config), + port_(0), socket_(socket), peerPort_(0), interruptListener_(interruptListener), @@ -522,6 +528,7 @@ void TSocket::setSocketFD(THRIFT_SOCKET socket) { } uint32_t TSocket::read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); if (socket_ == THRIFT_INVALID_SOCKET) { throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket"); } diff --git a/lib/cpp/src/thrift/transport/TSocket.h b/lib/cpp/src/thrift/transport/TSocket.h index b0e8ade83..043f0de8d 100644 --- a/lib/cpp/src/thrift/transport/TSocket.h +++ b/lib/cpp/src/thrift/transport/TSocket.h @@ -52,7 +52,7 @@ public: * socket. * */ - TSocket(); + TSocket(std::shared_ptr<TConfiguration> config = nullptr); /** * Constructs a new socket. Note that this does NOT actually connect the @@ -61,7 +61,7 @@ public: * @param host An IP address or hostname to connect to * @param port The port to connect on */ - TSocket(const std::string& host, int port); + TSocket(const std::string& host, int port, std::shared_ptr<TConfiguration> config = nullptr); /** * Constructs a new Unix domain socket. @@ -69,7 +69,7 @@ public: * * @param path The Unix domain socket e.g. "/tmp/ThriftTest.binary.thrift" */ - TSocket(const std::string& path); + TSocket(const std::string& path, std::shared_ptr<TConfiguration> config = nullptr); /** * Destroyes the socket object, closing it if necessary. @@ -264,13 +264,14 @@ public: /** * Constructor to create socket from file descriptor. */ - TSocket(THRIFT_SOCKET socket); + TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr); /** * Constructor to create socket from file descriptor that * can be interrupted safely. */ - TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener); + TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener, + std::shared_ptr<TConfiguration> config = nullptr); /** * Set a cache of the peer address (used when trivially available: e.g. diff --git a/lib/cpp/src/thrift/transport/TTransport.h b/lib/cpp/src/thrift/transport/TTransport.h index 63978829f..5f657f89b 100644 --- a/lib/cpp/src/thrift/transport/TTransport.h +++ b/lib/cpp/src/thrift/transport/TTransport.h @@ -21,6 +21,7 @@ #define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1 #include <thrift/Thrift.h> +#include <thrift/TConfiguration.h> #include <thrift/transport/TTransportException.h> #include <memory> #include <string> @@ -55,6 +56,15 @@ uint32_t readAll(Transport_& trans, uint8_t* buf, uint32_t len) { */ class TTransport { public: + TTransport(std::shared_ptr<TConfiguration> config = nullptr) { + if(config == nullptr) { + configuration_ = std::shared_ptr<TConfiguration> (new TConfiguration()); + } else { + configuration_ = config; + } + resetConsumedMessageSize(); + } + /** * Virtual deconstructor. */ @@ -238,11 +248,87 @@ public: */ virtual const std::string getOrigin() const { return "Unknown"; } + std::shared_ptr<TConfiguration> getConfiguration() { return configuration_; } + + void setConfiguration(std::shared_ptr<TConfiguration> config) { + if (config != nullptr) configuration_ = config; + } + + /** + * Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport). + * Will throw if we already consumed too many bytes or if the new size is larger than allowed. + * + * @param size real message size + */ + void updateKnownMessageSize(long int size) + { + long int consumed = knownMessageSize_ - remainingMessageSize_; + resetConsumedMessageSize(size); + countConsumedMessageBytes(consumed); + } + + /** + * Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data + * + * @param numBytes numBytes bytes of data + */ + void checkReadBytesAvailable(long int numBytes) + { + if (remainingMessageSize_ < numBytes) + throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached"); + } + protected: + std::shared_ptr<TConfiguration> configuration_; + long int remainingMessageSize_; + long int knownMessageSize_; + + inline long int getRemainingMessageSize() { return remainingMessageSize_; } + inline void setRemainingMessageSize(long int remainingMessageSize) { remainingMessageSize_ = remainingMessageSize; } + inline int getMaxMessageSize() { return configuration_->getMaxMessageSize(); } + inline long int getKnownMessageSize() { return knownMessageSize_; } + void setKnownMessageSize(long int knownMessageSize) { knownMessageSize_ = knownMessageSize; } + + /** + * Resets RemainingMessageSize to the configured maximum + * + * @param newSize configured size + */ + void resetConsumedMessageSize(long newSize = -1) + { + // full reset + if (newSize < 0) + { + knownMessageSize_ = getMaxMessageSize(); + remainingMessageSize_ = getMaxMessageSize(); + return; + } + + // update only: message size can shrink, but not grow + if (newSize > knownMessageSize_) + throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached"); + + knownMessageSize_ = newSize; + remainingMessageSize_ = newSize; + } + /** - * Simple constructor. + * Consumes numBytes from the RemainingMessageSize. + * + * @param numBytes Consumes numBytes */ - TTransport() = default; + void countConsumedMessageBytes(long int numBytes) + { + if (remainingMessageSize_ >= numBytes) + { + remainingMessageSize_ -= numBytes; + } + else + { + remainingMessageSize_ = 0; + throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached"); + } + } }; /** diff --git a/lib/cpp/src/thrift/transport/TTransportUtils.cpp b/lib/cpp/src/thrift/transport/TTransportUtils.cpp index 69372f3e2..427a2e7c1 100644 --- a/lib/cpp/src/thrift/transport/TTransportUtils.cpp +++ b/lib/cpp/src/thrift/transport/TTransportUtils.cpp @@ -26,6 +26,7 @@ namespace thrift { namespace transport { uint32_t TPipedTransport::read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); uint32_t need = len; // We don't have enough data yet @@ -104,8 +105,9 @@ void TPipedTransport::flush() { TPipedFileReaderTransport::TPipedFileReaderTransport( std::shared_ptr<TFileReaderTransport> srcTrans, - std::shared_ptr<TTransport> dstTrans) - : TPipedTransport(srcTrans, dstTrans), srcTrans_(srcTrans) { + std::shared_ptr<TTransport> dstTrans, + std::shared_ptr<TConfiguration> config) + : TPipedTransport(srcTrans, dstTrans, config), srcTrans_(srcTrans) { } TPipedFileReaderTransport::~TPipedFileReaderTransport() = default; @@ -131,6 +133,7 @@ uint32_t TPipedFileReaderTransport::read(uint8_t* buf, uint32_t len) { } uint32_t TPipedFileReaderTransport::readAll(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); uint32_t have = 0; uint32_t get = 0; diff --git a/lib/cpp/src/thrift/transport/TTransportUtils.h b/lib/cpp/src/thrift/transport/TTransportUtils.h index 28c93d2a1..68c25f4c9 100644 --- a/lib/cpp/src/thrift/transport/TTransportUtils.h +++ b/lib/cpp/src/thrift/transport/TTransportUtils.h @@ -63,8 +63,10 @@ public: */ class TPipedTransport : virtual public TTransport { public: - TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans) - : srcTrans_(srcTrans), + TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans, + std::shared_ptr<TConfiguration> config = nullptr) + : TTransport(config), + srcTrans_(srcTrans), dstTrans_(dstTrans), rBufSize_(512), rPos_(0), @@ -88,8 +90,10 @@ public: TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans, - uint32_t sz) - : srcTrans_(srcTrans), + uint32_t sz, + std::shared_ptr<TConfiguration> config = nullptr) + : TTransport(config), + srcTrans_(srcTrans), dstTrans_(dstTrans), rBufSize_(512), rPos_(0), @@ -241,7 +245,8 @@ protected: class TPipedFileReaderTransport : public TPipedTransport, public TFileReaderTransport { public: TPipedFileReaderTransport(std::shared_ptr<TFileReaderTransport> srcTrans, - std::shared_ptr<TTransport> dstTrans); + std::shared_ptr<TTransport> dstTrans, + std::shared_ptr<TConfiguration> config = nullptr); ~TPipedFileReaderTransport() override; diff --git a/lib/cpp/src/thrift/transport/TVirtualTransport.h b/lib/cpp/src/thrift/transport/TVirtualTransport.h index 0a0485742..44bfa1315 100644 --- a/lib/cpp/src/thrift/transport/TVirtualTransport.h +++ b/lib/cpp/src/thrift/transport/TVirtualTransport.h @@ -57,7 +57,7 @@ public: void consume(uint32_t len) { this->TTransport::consume_virt(len); } protected: - TTransportDefaults() = default; + TTransportDefaults(std::shared_ptr<TConfiguration> config = nullptr) : TTransport(config) {} }; /** @@ -118,7 +118,7 @@ public: } protected: - TVirtualTransport() = default; + TVirtualTransport() : Super_() {} /* * Templatized constructors, to allow arguments to be passed to the Super_ diff --git a/lib/cpp/src/thrift/transport/TWebSocketServer.h b/lib/cpp/src/thrift/transport/TWebSocketServer.h index 2e94c839d..7f39f36b9 100644 --- a/lib/cpp/src/thrift/transport/TWebSocketServer.h +++ b/lib/cpp/src/thrift/transport/TWebSocketServer.h @@ -53,8 +53,8 @@ std::string base64Encode(unsigned char* data, int length); template <bool binary> class TWebSocketServer : public THttpServer { public: - TWebSocketServer(std::shared_ptr<TTransport> transport) - : THttpServer(transport) { + TWebSocketServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr) + : THttpServer(transport, config) { resetHandshake(); } @@ -98,6 +98,7 @@ public: } void flush() override { + resetConsumedMessageSize(); writeFrameHeader(); uint8_t* buffer; uint32_t length; diff --git a/lib/cpp/src/thrift/transport/TZlibTransport.cpp b/lib/cpp/src/thrift/transport/TZlibTransport.cpp index b4c43d647..657ce5205 100644 --- a/lib/cpp/src/thrift/transport/TZlibTransport.cpp +++ b/lib/cpp/src/thrift/transport/TZlibTransport.cpp @@ -136,6 +136,7 @@ inline int TZlibTransport::readAvail() const { } uint32_t TZlibTransport::read(uint8_t* buf, uint32_t len) { + checkReadBytesAvailable(len); uint32_t need = len; // TODO(dreiss): Skip urbuf on big reads. @@ -265,6 +266,7 @@ void TZlibTransport::flush() { } flushToTransport(Z_FULL_FLUSH); + resetConsumedMessageSize(); } void TZlibTransport::finish() { @@ -335,6 +337,7 @@ const uint8_t* TZlibTransport::borrow(uint8_t* buf, uint32_t* len) { } void TZlibTransport::consume(uint32_t len) { + countConsumedMessageBytes(len); if (readAvail() >= (int)len) { urpos_ += len; } else { diff --git a/lib/cpp/src/thrift/transport/TZlibTransport.h b/lib/cpp/src/thrift/transport/TZlibTransport.h index 4990afff5..85765e6be 100644 --- a/lib/cpp/src/thrift/transport/TZlibTransport.h +++ b/lib/cpp/src/thrift/transport/TZlibTransport.h @@ -83,8 +83,10 @@ public: int crbuf_size = DEFAULT_CRBUF_SIZE, int uwbuf_size = DEFAULT_UWBUF_SIZE, int cwbuf_size = DEFAULT_CWBUF_SIZE, - int16_t comp_level = Z_DEFAULT_COMPRESSION) - : transport_(transport), + int16_t comp_level = Z_DEFAULT_COMPRESSION, + std::shared_ptr<TConfiguration> config = nullptr) + : TVirtualTransport(config), + transport_(transport), urpos_(0), uwpos_(0), input_ended_(false), diff --git a/lib/cpp/test/CMakeLists.txt b/lib/cpp/test/CMakeLists.txt index 48e2fd375..ced78a257 100644 --- a/lib/cpp/test/CMakeLists.txt +++ b/lib/cpp/test/CMakeLists.txt @@ -81,6 +81,7 @@ set(UnitTest_SOURCES TypedefTest.cpp TServerSocketTest.cpp TServerTransportTest.cpp + ThrifttReadCheckTests.cpp ) add_executable(UnitTests ${UnitTest_SOURCES}) diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am index 89826839d..7f630db10 100755 --- a/lib/cpp/test/Makefile.am +++ b/lib/cpp/test/Makefile.am @@ -130,7 +130,8 @@ UnitTests_SOURCES = \ TypedefTest.cpp \ TServerSocketTest.cpp \ TServerTransportTest.cpp \ - TTransportCheckThrow.h + TTransportCheckThrow.h \ + ThrifttReadCheckTests.cpp UnitTests_LDADD = \ libtestgencpp.la \ diff --git a/lib/cpp/test/ThrifttReadCheckTests.cpp b/lib/cpp/test/ThrifttReadCheckTests.cpp new file mode 100644 index 000000000..4a594e6ca --- /dev/null +++ b/lib/cpp/test/ThrifttReadCheckTests.cpp @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#define MAX_MESSAGE_SIZE 2 + +#include <boost/test/auto_unit_test.hpp> +#include <boost/test/unit_test.hpp> +#include <iostream> +#include <climits> +#include <vector> +#include <thrift/TConfiguration.h> +#include <thrift/protocol/TBinaryProtocol.h> +#include <thrift/protocol/TCompactProtocol.h> +#include <thrift/protocol/TJSONProtocol.h> +#include <thrift/Thrift.h> +#include <memory> +#include <thrift/transport/TTransportUtils.h> +#include <thrift/transport/TBufferTransports.h> +#include <thrift/transport/TSimpleFileTransport.h> +#include <thrift/transport/TFileTransport.h> +#include <thrift/protocol/TEnum.h> +#include <thrift/protocol/TList.h> +#include <thrift/protocol/TSet.h> +#include <thrift/protocol/TMap.h> + +BOOST_AUTO_TEST_SUITE(ThriftReadCheckExceptionTest) + +using apache::thrift::TConfiguration; +using apache::thrift::protocol::TBinaryProtocol; +using apache::thrift::protocol::TCompactProtocol; +using apache::thrift::protocol::TJSONProtocol; +using apache::thrift::protocol::TType; +using apache::thrift::transport::TPipedTransport; +using apache::thrift::transport::TMemoryBuffer; +using apache::thrift::transport::TSimpleFileTransport; +using apache::thrift::transport::TFileTransport; +using apache::thrift::transport::TFDTransport; +using apache::thrift::transport::TTransportException; +using apache::thrift::transport::TBufferedTransport; +using apache::thrift::transport::TFramedTransport; +using std::shared_ptr; +using std::cout; +using std::endl; +using std::string; +using std::memset; +using namespace apache::thrift; +using namespace apache::thrift::protocol; + + +BOOST_AUTO_TEST_CASE(test_tmemorybuffer_read_check_exception) { + std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE)); + TMemoryBuffer trans_out(config); + uint8_t buffer[6] = {1, 2, 3, 4, 5, 6}; + trans_out.write((const uint8_t*)buffer, sizeof(buffer)); + trans_out.close(); + + TMemoryBuffer trans_in(config); + memset(buffer, 0, sizeof(buffer)); + BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*); + trans_in.close(); +} + +BOOST_AUTO_TEST_CASE(test_tpipedtransport_read_check_exception) { + std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE)); + std::shared_ptr<TMemoryBuffer> pipe(new TMemoryBuffer); + std::shared_ptr<TMemoryBuffer> underlying(new TMemoryBuffer); + std::shared_ptr<TPipedTransport> trans(new TPipedTransport(underlying, pipe, config)); + + uint8_t buffer[4]; + + underlying->write((uint8_t*)"abcd", 4); + BOOST_CHECK_THROW(trans->read(buffer, sizeof(buffer)), TTransportException*); + BOOST_CHECK_THROW(trans->readAll(buffer, sizeof(buffer)), TTransportException*); + trans->readEnd(); + pipe->resetBuffer(); + underlying->write((uint8_t*)"ef", 2); + BOOST_CHECK_THROW(trans->read(buffer, sizeof(buffer)), TTransportException*); + BOOST_CHECK_THROW(trans->readAll(buffer, sizeof(buffer)), TTransportException*); + trans->readEnd(); +} + +BOOST_AUTO_TEST_CASE(test_tsimplefiletransport_read_check_exception) { + std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE)); + TSimpleFileTransport trans_out("data", false, true, config); + uint8_t buffer[6] = {1, 2, 3, 4, 5, 6}; + trans_out.write((const uint8_t*)buffer, sizeof(buffer)); + trans_out.close(); + + TSimpleFileTransport trans_in("data",true, false, config); + memset(buffer, 0, sizeof(buffer)); + BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*); + trans_in.close(); + + remove("./data"); +} + +BOOST_AUTO_TEST_CASE(test_tfiletransport_read_check_exception) { + std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE)); + TFileTransport trans_out("data", false, config); + uint8_t buffer[6] = {1, 2, 3, 4, 5, 6}; + trans_out.write((const uint8_t*)buffer, sizeof(buffer)); + + TFileTransport trans_in("data", false, config); + memset(buffer, 0, sizeof(buffer)); + BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*); + + remove("./data"); +} + +BOOST_AUTO_TEST_CASE(test_tbufferedtransport_read_check_exception) { + uint8_t arr[4] = {1, 2, 3, 4}; + std::shared_ptr<TMemoryBuffer> buffer (new TMemoryBuffer(arr, sizeof(arr))); + std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE)); + std::shared_ptr<TBufferedTransport> trans (new TBufferedTransport(buffer, config)); + + trans->write((const uint8_t*)arr, sizeof(arr)); + BOOST_CHECK_THROW(trans->read(arr, sizeof(arr)), TTransportException*); +} + +BOOST_AUTO_TEST_CASE(test_tframedtransport_read_check_exception) { + uint8_t arr[4] = {1, 2, 3, 4}; + std::shared_ptr<TMemoryBuffer> buffer (new TMemoryBuffer(arr, sizeof(arr))); + std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE)); + std::shared_ptr<TFramedTransport> trans (new TFramedTransport(buffer, config)); + + trans->write((const uint8_t*)arr, sizeof(arr)); + BOOST_CHECK_THROW(trans->read(arr, sizeof(arr)), TTransportException*); +} + +BOOST_AUTO_TEST_CASE(test_tthriftbinaryprotocol_read_check_exception) { + std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE)); + std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config)); + std::shared_ptr<TBinaryProtocol> protocol(new TBinaryProtocol(transport)); + + uint32_t val = 0; + TType elemType = apache::thrift::protocol::T_STOP; + TType elemType1 = apache::thrift::protocol::T_STOP; + TList list(T_I32, 8); + protocol->writeListBegin(list.elemType_, list.size_); + protocol->writeListEnd(); + BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*); + protocol->readListEnd(); + + TSet set(T_I32, 8); + protocol->writeSetBegin(set.elemType_, set.size_); + protocol->writeSetEnd(); + BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*); + protocol->readSetEnd(); + + TMap map(T_I32, T_I32, 8); + protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_); + protocol->writeMapEnd(); + BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*); + protocol->readMapEnd(); +} + +BOOST_AUTO_TEST_CASE(test_tthriftcompactprotocol_read_check_exception) { + std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE)); + std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config)); + std::shared_ptr<TCompactProtocol> protocol(new TCompactProtocol(transport)); + + uint32_t val = 0; + TType elemType = apache::thrift::protocol::T_STOP; + TType elemType1 = apache::thrift::protocol::T_STOP; + TList list(T_I32, 8); + protocol->writeListBegin(list.elemType_, list.size_); + protocol->writeListEnd(); + BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*); + protocol->readListEnd(); + + TSet set(T_I32, 8); + protocol->writeSetBegin(set.elemType_, set.size_); + protocol->writeSetEnd(); + BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*); + protocol->readSetEnd(); + + TMap map(T_I32, T_I32, 8); + protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_); + protocol->writeMapEnd(); + BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*); + protocol->readMapEnd(); +} + +BOOST_AUTO_TEST_CASE(test_tthriftjsonprotocol_read_check_exception) { + std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE)); + std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config)); + std::shared_ptr<TJSONProtocol> protocol(new TJSONProtocol(transport)); + + uint32_t val = 0; + TType elemType = apache::thrift::protocol::T_STOP; + TType elemType1 = apache::thrift::protocol::T_STOP; + TList list(T_I32, 8); + protocol->writeListBegin(list.elemType_, list.size_); + protocol->writeListEnd(); + BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*); + protocol->readListEnd(); + + TSet set(T_I32, 8); + protocol->writeSetBegin(set.elemType_, set.size_); + protocol->writeSetEnd(); + BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*); + protocol->readSetEnd(); + + TMap map(T_I32, T_I32, 8); + protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_); + protocol->writeMapEnd(); + BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*); + protocol->readMapEnd(); +} + +BOOST_AUTO_TEST_SUITE_END() |