diff options
author | Nobuaki Sukegawa <nsuke@apache.org> | 2016-01-11 13:46:04 +0900 |
---|---|---|
committer | Nobuaki Sukegawa <nsuke@apache.org> | 2016-01-13 20:25:23 +0900 |
commit | 7b545b57200ab960894e873716862cafbf9321f0 (patch) | |
tree | 56b13f14a20877edf7f7d33fe05e62e309f877ca | |
parent | 7e286b0d143be88adbd84f2e1cbfec66196a6a57 (diff) | |
download | thrift-7b545b57200ab960894e873716862cafbf9321f0.tar.gz |
THRIFT-3532 Add configurable string and container read size limit to Python protocols
This closes #787
-rw-r--r-- | compiler/cpp/src/generate/t_py_generator.cc | 4 |
-rw-r--r-- | lib/py/src/protocol/TBase.py | 8 |
-rw-r--r-- | lib/py/src/protocol/TBinaryProtocol.py | 37 |
-rw-r--r-- | lib/py/src/protocol/TCompactProtocol.py | 30 |
-rw-r--r-- | lib/py/src/protocol/TJSONProtocol.py | 18 |
-rw-r--r-- | lib/py/src/protocol/TProtocol.py | 10 |
-rw-r--r-- | lib/py/src/protocol/fastbinary.c | 57 |
-rw-r--r-- | lib/py/src/transport/TTransport.py | 2 |
-rw-r--r-- | test/features/known_failures_Linux.json | 10 |
-rwxr-xr-x | test/py/TestServer.py | 16 |
10 files changed, 140 insertions, 52 deletions
diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc index 3c70248cc..2d30a7dca 100644 --- a/compiler/cpp/src/generate/t_py_generator.cc +++ b/compiler/cpp/src/generate/t_py_generator.cc @@ -902,10 +902,10 @@ void t_py_generator::generate_py_struct_reader(ofstream& out, t_struct* tstruct) if (is_immutable(tstruct)) { indent(out) - << "return fastbinary.decode_binary(None, iprot.trans, (cls, cls.thrift_spec))" + << "return fastbinary.decode_binary(None, iprot.trans, (cls, cls.thrift_spec), iprot.string_length_limit, iprot.container_length_limit)" << endl; } else { - indent(out) << "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))" + indent(out) << "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec), iprot.string_length_limit, iprot.container_length_limit)" << endl; indent(out) << "return" << endl; } diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py index 4f71e119e..d106f4e03 100644 --- a/lib/py/src/protocol/TBase.py +++ b/lib/py/src/protocol/TBase.py @@ -53,7 +53,9 @@ class TBase(object): fastbinary is not None): fastbinary.decode_binary(self, iprot.trans, - (self.__class__, self.thrift_spec)) + (self.__class__, self.thrift_spec), + iprot.string_length_limit, + iprot.container_length_limit) return iprot.readStruct(self, self.thrift_spec) @@ -90,5 +92,7 @@ class TFrozenBase(TBase): self = cls() return fastbinary.decode_binary(None, iprot.trans, - (self.__class__, self.thrift_spec)) + (self.__class__, self.thrift_spec), + iprot.string_length_limit, + iprot.container_length_limit) return iprot.readStruct(cls, cls.thrift_spec, True) diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py index 43cb5a476..db4ea3182 100644 --- a/lib/py/src/protocol/TBinaryProtocol.py +++ b/lib/py/src/protocol/TBinaryProtocol.py @@ -36,10 +36,18 @@ class TBinaryProtocol(TProtocolBase): TYPE_MASK = 0x000000ff - def 
__init__(self, trans, strictRead=False, strictWrite=True): + def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs): TProtocolBase.__init__(self, trans) self.strictRead = strictRead self.strictWrite = strictWrite + self.string_length_limit = kwargs.get('string_length_limit', None) + self.container_length_limit = kwargs.get('container_length_limit', None) + + def _check_string_length(self, length): + self._check_length(self.string_length_limit, length) + + def _check_container_length(self, length): + self._check_length(self.container_length_limit, length) def writeMessageBegin(self, name, type, seqid): if self.strictWrite: @@ -165,6 +173,7 @@ class TBinaryProtocol(TProtocolBase): ktype = self.readByte() vtype = self.readByte() size = self.readI32() + self._check_container_length(size) return (ktype, vtype, size) def readMapEnd(self): @@ -173,6 +182,7 @@ class TBinaryProtocol(TProtocolBase): def readListBegin(self): etype = self.readByte() size = self.readI32() + self._check_container_length(size) return (etype, size) def readListEnd(self): @@ -181,6 +191,7 @@ class TBinaryProtocol(TProtocolBase): def readSetBegin(self): etype = self.readByte() size = self.readI32() + self._check_container_length(size) return (etype, size) def readSetEnd(self): @@ -218,18 +229,23 @@ class TBinaryProtocol(TProtocolBase): return val def readBinary(self): - len = self.readI32() - s = self.trans.readAll(len) + size = self.readI32() + self._check_string_length(size) + s = self.trans.readAll(size) return s class TBinaryProtocolFactory(object): - def __init__(self, strictRead=False, strictWrite=True): + def __init__(self, strictRead=False, strictWrite=True, **kwargs): self.strictRead = strictRead self.strictWrite = strictWrite + self.string_length_limit = kwargs.get('string_length_limit', None) + self.container_length_limit = kwargs.get('container_length_limit', None) def getProtocol(self, trans): - prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) + prot = 
TBinaryProtocol(trans, self.strictRead, self.strictWrite, + string_length_limit=self.string_length_limit, + container_length_limit=self.container_length_limit) return prot @@ -256,5 +272,14 @@ class TBinaryProtocolAccelerated(TBinaryProtocol): class TBinaryProtocolAcceleratedFactory(object): + def __init__(self, + string_length_limit=None, + container_length_limit=None): + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit + def getProtocol(self, trans): - return TBinaryProtocolAccelerated(trans) + return TBinaryProtocolAccelerated( + trans, + string_length_limit=self.string_length_limit, + container_length_limit=self.container_length_limit) diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py index 6023066d0..3d9c0e6e3 100644 --- a/lib/py/src/protocol/TCompactProtocol.py +++ b/lib/py/src/protocol/TCompactProtocol.py @@ -126,7 +126,9 @@ class TCompactProtocol(TProtocolBase): TYPE_BITS = 0x07 TYPE_SHIFT_AMOUNT = 5 - def __init__(self, trans): + def __init__(self, trans, + string_length_limit=None, + container_length_limit=None): TProtocolBase.__init__(self, trans) self.state = CLEAR self.__last_fid = 0 @@ -134,6 +136,14 @@ class TCompactProtocol(TProtocolBase): self.__bool_value = None self.__structs = [] self.__containers = [] + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit + + def _check_string_length(self, length): + self._check_length(self.string_length_limit, length) + + def _check_container_length(self, length): + self._check_length(self.container_length_limit, length) def __writeVarint(self, n): writeVarint(self.trans, n) @@ -344,6 +354,7 @@ class TCompactProtocol(TProtocolBase): type = self.__getTType(size_type) if size == 15: size = self.__readSize() + self._check_container_length(size) self.__containers.append(self.state) self.state = CONTAINER_READ return type, size @@ -353,6 +364,7 @@ class 
TCompactProtocol(TProtocolBase): def readMapBegin(self): assert self.state in (VALUE_READ, CONTAINER_READ), self.state size = self.__readSize() + self._check_container_length(size) types = 0 if size > 0: types = self.__readUByte() @@ -391,8 +403,9 @@ class TCompactProtocol(TProtocolBase): return val def __readBinary(self): - len = self.__readSize() - return self.trans.readAll(len) + size = self.__readSize() + self._check_string_length(size) + return self.trans.readAll(size) readBinary = reader(__readBinary) def __getTType(self, byte): @@ -400,8 +413,13 @@ class TCompactProtocol(TProtocolBase): class TCompactProtocolFactory(object): - def __init__(self): - pass + def __init__(self, + string_length_limit=None, + container_length_limit=None): + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit def getProtocol(self, trans): - return TCompactProtocol(trans) + return TCompactProtocol(trans, + self.string_length_limit, + self.container_length_limit) diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py index 3612e91a5..f9e65fbf2 100644 --- a/lib/py/src/protocol/TJSONProtocol.py +++ b/lib/py/src/protocol/TJSONProtocol.py @@ -175,6 +175,15 @@ class TJSONProtocolBase(TProtocolBase): self.resetWriteContext() self.resetReadContext() + # We don't have length limit implementation for JSON protocols + @property + def string_length_limit(senf): + return None + + @property + def container_length_limit(senf): + return None + def resetWriteContext(self): self.context = JSONBaseContext(self) self.contextStack = [self.context] @@ -560,10 +569,17 @@ class TJSONProtocol(TJSONProtocolBase): class TJSONProtocolFactory(object): - def getProtocol(self, trans): return TJSONProtocol(trans) + @property + def string_length_limit(senf): + return None + + @property + def container_length_limit(senf): + return None + class TSimpleJSONProtocol(TJSONProtocolBase): """Simple, readable, write-only JSON protocol. 
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index 450e0fa7b..d9aa2e82b 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -18,6 +18,7 @@ # from thrift.Thrift import TException, TType, TFrozenDict +from thrift.transport.TTransport import TTransportException from ..compat import binary_to_str, str_to_binary import six @@ -48,6 +49,15 @@ class TProtocolBase(object): def __init__(self, trans): self.trans = trans + @staticmethod + def _check_length(limit, length): + if length < 0: + raise TTransportException(TTransportException.NEGATIVE_SIZE, + 'Negative length: %d' % length) + if limit is not None and length > limit: + raise TTransportException(TTransportException.SIZE_LIMIT, + 'Length exceeded max allowed: %d' % limit) + def writeMessageBegin(self, name, ttype, seqid): pass diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c index 337201b26..da57c8559 100644 --- a/lib/py/src/protocol/fastbinary.c +++ b/lib/py/src/protocol/fastbinary.c @@ -189,22 +189,19 @@ check_ssize_t_32(Py_ssize_t len) { return false; } if (!CHECK_RANGE(len, 0, INT32_MAX)) { - PyErr_SetString(PyExc_OverflowError, "string size out of range"); + PyErr_SetString(PyExc_OverflowError, "size out of range: exceeded INT32_MAX"); return false; } return true; } -#define MAX_LIST_SIZE (10000) - static inline bool -check_list_length(Py_ssize_t len) { - // error from getting the int - if (INT_CONV_ERROR_OCCURRED(len)) { +check_length_limit(Py_ssize_t len, long limit) { + if (!check_ssize_t_32(len)) { return false; } - if (!CHECK_RANGE(len, 0, MAX_LIST_SIZE)) { - PyErr_SetString(PyExc_OverflowError, "list size out of the sanity limit (10000 items max)"); + if (len > limit) { + PyErr_Format(PyExc_OverflowError, "size exceeded specified limit: %d", limit); return false; } return true; @@ -891,10 +888,10 @@ skip(DecodeBuffer* input, TType type) { /* --- HELPER FUNCTION FOR DECODE_VAL --- */ static PyObject* 
-decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs, long string_limit, long container_limit); static PyObject* -decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq) { +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq, long string_limit, long container_limit) { int spec_seq_len = PyTuple_Size(spec_seq); bool immutable = output == Py_None; PyObject* kwargs = NULL; @@ -954,7 +951,7 @@ decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* } } - fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); + fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs, string_limit, container_limit); if (fieldval == NULL) { goto error; } @@ -991,7 +988,7 @@ decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* // Returns a new reference. static PyObject* -decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs, long string_limit, long container_limit) { switch (type) { case T_BOOL: { @@ -1059,6 +1056,9 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { if (!readBytes(input, &buf, len)) { return NULL; } + if (!check_length_limit(len, string_limit)) { + return NULL; + } if (is_utf8(typeargs)) return PyUnicode_DecodeUTF8(buf, len, 0); @@ -1083,7 +1083,7 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { } len = readI32(input); - if (!check_list_length(len)) { + if (!check_length_limit(len, container_limit)) { return NULL; } @@ -1094,7 +1094,7 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { } for (i = 0; i < len; i++) { - PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); + PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs, string_limit, container_limit); if (!item) { 
Py_DECREF(ret); return NULL; @@ -1135,8 +1135,8 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { } len = readI32(input); - if (!check_ssize_t_32(len)) { - return false; + if (!check_length_limit(len, container_limit)) { + return NULL; } ret = PyDict_New(); @@ -1147,11 +1147,11 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { for (i = 0; i < len; i++) { PyObject* k = NULL; PyObject* v = NULL; - k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); + k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs, string_limit, container_limit); if (k == NULL) { goto loop_error; } - v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); + v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs, string_limit, container_limit); if (v == NULL) { goto loop_error; } @@ -1199,7 +1199,7 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { return NULL; } - return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec); + return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec, string_limit, container_limit); } case T_STOP: @@ -1213,6 +1213,15 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { } } +static long as_long_or(PyObject* value, long default_value) { + long v = PyInt_AsLong(value); + if (INT_CONV_ERROR_OCCURRED(v)) { + PyErr_Clear(); + return default_value; + } + return v; +} + /* --- TOP-LEVEL WRAPPER FOR INPUT -- */ @@ -1222,12 +1231,18 @@ decode_binary(PyObject *self, PyObject *args) { PyObject* transport = NULL; PyObject* typeargs = NULL; StructTypeArgs parsedargs; + PyObject* string_limit_obj = NULL; + PyObject* container_limit_obj = NULL; + long string_limit = 0; + long container_limit = 0; DecodeBuffer input = {0, 0}; PyObject* ret = NULL; - if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { + if (!PyArg_ParseTuple(args, "OOOOO", &output_obj, &transport, &typeargs, &string_limit_obj, &container_limit_obj)) { return NULL; } + 
string_limit = as_long_or(string_limit_obj, INT32_MAX); + container_limit = as_long_or(container_limit_obj, INT32_MAX); if (!parse_struct_args(&parsedargs, typeargs)) { return NULL; @@ -1237,7 +1252,7 @@ decode_binary(PyObject *self, PyObject *args) { return NULL; } - ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec); + ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec, string_limit, container_limit); free_decodebuf(&input); return ret; } diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py index 8e2da8dad..f99b3b9ba 100644 --- a/lib/py/src/transport/TTransport.py +++ b/lib/py/src/transport/TTransport.py @@ -30,6 +30,8 @@ class TTransportException(TException): ALREADY_OPEN = 2 TIMED_OUT = 3 END_OF_FILE = 4 + NEGATIVE_SIZE = 5 + SIZE_LIMIT = 6 def __init__(self, type=UNKNOWN, message=None): TException.__init__(self, message) diff --git a/test/features/known_failures_Linux.json b/test/features/known_failures_Linux.json index edff41a76..7a510839a 100644 --- a/test/features/known_failures_Linux.json +++ b/test/features/known_failures_Linux.json @@ -27,16 +27,6 @@ "nodejs-limit_string_length_compact_buffered-ip", "perl-limit_container_length_binary_buffered-ip", "perl-limit_string_length_binary_buffered-ip", - "py-limit_container_length_accel-binary_buffered-ip", - "py-limit_container_length_binary_buffered-ip", - "py-limit_container_length_compact_buffered-ip", - "py-limit_string_length_accel-binary_buffered-ip", - "py-limit_string_length_binary_buffered-ip", - "py-limit_string_length_compact_buffered-ip", - "py3-limit_container_length_binary_buffered-ip", - "py3-limit_container_length_compact_buffered-ip", - "py3-limit_string_length_binary_buffered-ip", - "py3-limit_string_length_compact_buffered-ip", "rb-limit_container_length_accel-binary_buffered-ip", "rb-limit_container_length_binary_buffered-ip", "rb-limit_container_length_compact_buffered-ip", diff --git a/test/py/TestServer.py 
b/test/py/TestServer.py index 4fa889460..f12a9fe76 100755 --- a/test/py/TestServer.py +++ b/test/py/TestServer.py @@ -18,7 +18,7 @@ # specific language governing permissions and limitations # under the License. # -from __future__ import division, print_function +from __future__ import division import glob import logging import os @@ -179,9 +179,6 @@ class TestHandler(object): def main(options): - # Print TServer log to stdout so that the test-runner can redirect it to log files - logging.basicConfig(level=logging.DEBUG) - # set up the protocol factory form the --protocol option prot_factories = { 'binary': TBinaryProtocol.TBinaryProtocolFactory, @@ -193,6 +190,12 @@ def main(options): if pfactory_cls is None: raise AssertionError('Unknown --protocol option: %s' % options.proto) pfactory = pfactory_cls() + try: + pfactory.string_length_limit = options.string_limit + pfactory.container_length_limit = options.container_limit + except: + # Ignore errors for those protocols that does not support length limit + pass # get the server type (TSimpleServer, TNonblockingServer, etc...) if len(args) > 1: @@ -287,9 +290,14 @@ if __name__ == '__main__': help="protocol to use, one of: accel, binary, compact, json") parser.add_option('--transport', dest="trans", type="string", help="transport to use, one of: buffered, framed") + parser.add_option('--container-limit', dest='container_limit', type='int', default=None) + parser.add_option('--string-limit', dest='string_limit', type='int', default=None) parser.set_defaults(port=9090, verbose=1, proto='binary') options, args = parser.parse_args() + # Print TServer log to stdout so that the test-runner can redirect it to log files + logging.basicConfig(level=options.verbose) + sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir)) if options.libpydir: sys.path.insert(0, glob.glob(options.libpydir)[0]) |