summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNobuaki Sukegawa <nsuke@apache.org>2016-01-11 13:46:04 +0900
committerNobuaki Sukegawa <nsuke@apache.org>2016-01-13 20:25:23 +0900
commit7b545b57200ab960894e873716862cafbf9321f0 (patch)
tree56b13f14a20877edf7f7d33fe05e62e309f877ca
parent7e286b0d143be88adbd84f2e1cbfec66196a6a57 (diff)
downloadthrift-7b545b57200ab960894e873716862cafbf9321f0.tar.gz
THRIFT-3532 Add configurable string and container read size limit to Python protocols
This closes #787
-rw-r--r--compiler/cpp/src/generate/t_py_generator.cc4
-rw-r--r--lib/py/src/protocol/TBase.py8
-rw-r--r--lib/py/src/protocol/TBinaryProtocol.py37
-rw-r--r--lib/py/src/protocol/TCompactProtocol.py30
-rw-r--r--lib/py/src/protocol/TJSONProtocol.py18
-rw-r--r--lib/py/src/protocol/TProtocol.py10
-rw-r--r--lib/py/src/protocol/fastbinary.c57
-rw-r--r--lib/py/src/transport/TTransport.py2
-rw-r--r--test/features/known_failures_Linux.json10
-rwxr-xr-xtest/py/TestServer.py16
10 files changed, 140 insertions, 52 deletions
diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc
index 3c70248cc..2d30a7dca 100644
--- a/compiler/cpp/src/generate/t_py_generator.cc
+++ b/compiler/cpp/src/generate/t_py_generator.cc
@@ -902,10 +902,10 @@ void t_py_generator::generate_py_struct_reader(ofstream& out, t_struct* tstruct)
if (is_immutable(tstruct)) {
indent(out)
- << "return fastbinary.decode_binary(None, iprot.trans, (cls, cls.thrift_spec))"
+ << "return fastbinary.decode_binary(None, iprot.trans, (cls, cls.thrift_spec), iprot.string_length_limit, iprot.container_length_limit)"
<< endl;
} else {
- indent(out) << "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))"
+ indent(out) << "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec), iprot.string_length_limit, iprot.container_length_limit)"
<< endl;
indent(out) << "return" << endl;
}
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 4f71e119e..d106f4e03 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -53,7 +53,9 @@ class TBase(object):
fastbinary is not None):
fastbinary.decode_binary(self,
iprot.trans,
- (self.__class__, self.thrift_spec))
+ (self.__class__, self.thrift_spec),
+ iprot.string_length_limit,
+ iprot.container_length_limit)
return
iprot.readStruct(self, self.thrift_spec)
@@ -90,5 +92,7 @@ class TFrozenBase(TBase):
self = cls()
return fastbinary.decode_binary(None,
iprot.trans,
- (self.__class__, self.thrift_spec))
+ (self.__class__, self.thrift_spec),
+ iprot.string_length_limit,
+ iprot.container_length_limit)
return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 43cb5a476..db4ea3182 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -36,10 +36,18 @@ class TBinaryProtocol(TProtocolBase):
TYPE_MASK = 0x000000ff
- def __init__(self, trans, strictRead=False, strictWrite=True):
+ def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
TProtocolBase.__init__(self, trans)
self.strictRead = strictRead
self.strictWrite = strictWrite
+ self.string_length_limit = kwargs.get('string_length_limit', None)
+ self.container_length_limit = kwargs.get('container_length_limit', None)
+
+ def _check_string_length(self, length):
+ self._check_length(self.string_length_limit, length)
+
+ def _check_container_length(self, length):
+ self._check_length(self.container_length_limit, length)
def writeMessageBegin(self, name, type, seqid):
if self.strictWrite:
@@ -165,6 +173,7 @@ class TBinaryProtocol(TProtocolBase):
ktype = self.readByte()
vtype = self.readByte()
size = self.readI32()
+ self._check_container_length(size)
return (ktype, vtype, size)
def readMapEnd(self):
@@ -173,6 +182,7 @@ class TBinaryProtocol(TProtocolBase):
def readListBegin(self):
etype = self.readByte()
size = self.readI32()
+ self._check_container_length(size)
return (etype, size)
def readListEnd(self):
@@ -181,6 +191,7 @@ class TBinaryProtocol(TProtocolBase):
def readSetBegin(self):
etype = self.readByte()
size = self.readI32()
+ self._check_container_length(size)
return (etype, size)
def readSetEnd(self):
@@ -218,18 +229,23 @@ class TBinaryProtocol(TProtocolBase):
return val
def readBinary(self):
- len = self.readI32()
- s = self.trans.readAll(len)
+ size = self.readI32()
+ self._check_string_length(size)
+ s = self.trans.readAll(size)
return s
class TBinaryProtocolFactory(object):
- def __init__(self, strictRead=False, strictWrite=True):
+ def __init__(self, strictRead=False, strictWrite=True, **kwargs):
self.strictRead = strictRead
self.strictWrite = strictWrite
+ self.string_length_limit = kwargs.get('string_length_limit', None)
+ self.container_length_limit = kwargs.get('container_length_limit', None)
def getProtocol(self, trans):
- prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite)
+ prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
return prot
@@ -256,5 +272,14 @@ class TBinaryProtocolAccelerated(TBinaryProtocol):
class TBinaryProtocolAcceleratedFactory(object):
+ def __init__(self,
+ string_length_limit=None,
+ container_length_limit=None):
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
+
def getProtocol(self, trans):
- return TBinaryProtocolAccelerated(trans)
+ return TBinaryProtocolAccelerated(
+ trans,
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 6023066d0..3d9c0e6e3 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -126,7 +126,9 @@ class TCompactProtocol(TProtocolBase):
TYPE_BITS = 0x07
TYPE_SHIFT_AMOUNT = 5
- def __init__(self, trans):
+ def __init__(self, trans,
+ string_length_limit=None,
+ container_length_limit=None):
TProtocolBase.__init__(self, trans)
self.state = CLEAR
self.__last_fid = 0
@@ -134,6 +136,14 @@ class TCompactProtocol(TProtocolBase):
self.__bool_value = None
self.__structs = []
self.__containers = []
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
+
+ def _check_string_length(self, length):
+ self._check_length(self.string_length_limit, length)
+
+ def _check_container_length(self, length):
+ self._check_length(self.container_length_limit, length)
def __writeVarint(self, n):
writeVarint(self.trans, n)
@@ -344,6 +354,7 @@ class TCompactProtocol(TProtocolBase):
type = self.__getTType(size_type)
if size == 15:
size = self.__readSize()
+ self._check_container_length(size)
self.__containers.append(self.state)
self.state = CONTAINER_READ
return type, size
@@ -353,6 +364,7 @@ class TCompactProtocol(TProtocolBase):
def readMapBegin(self):
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
size = self.__readSize()
+ self._check_container_length(size)
types = 0
if size > 0:
types = self.__readUByte()
@@ -391,8 +403,9 @@ class TCompactProtocol(TProtocolBase):
return val
def __readBinary(self):
- len = self.__readSize()
- return self.trans.readAll(len)
+ size = self.__readSize()
+ self._check_string_length(size)
+ return self.trans.readAll(size)
readBinary = reader(__readBinary)
def __getTType(self, byte):
@@ -400,8 +413,13 @@ class TCompactProtocol(TProtocolBase):
class TCompactProtocolFactory(object):
- def __init__(self):
- pass
+ def __init__(self,
+ string_length_limit=None,
+ container_length_limit=None):
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
def getProtocol(self, trans):
- return TCompactProtocol(trans)
+ return TCompactProtocol(trans,
+ self.string_length_limit,
+ self.container_length_limit)
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index 3612e91a5..f9e65fbf2 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -175,6 +175,15 @@ class TJSONProtocolBase(TProtocolBase):
self.resetWriteContext()
self.resetReadContext()
+ # We don't have a length limit implementation for JSON protocols
+ @property
+ def string_length_limit(senf):
+ return None
+
+ @property
+ def container_length_limit(senf):
+ return None
+
def resetWriteContext(self):
self.context = JSONBaseContext(self)
self.contextStack = [self.context]
@@ -560,10 +569,17 @@ class TJSONProtocol(TJSONProtocolBase):
class TJSONProtocolFactory(object):
-
def getProtocol(self, trans):
return TJSONProtocol(trans)
+ @property
+ def string_length_limit(senf):
+ return None
+
+ @property
+ def container_length_limit(senf):
+ return None
+
class TSimpleJSONProtocol(TJSONProtocolBase):
"""Simple, readable, write-only JSON protocol.
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index 450e0fa7b..d9aa2e82b 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -18,6 +18,7 @@
#
from thrift.Thrift import TException, TType, TFrozenDict
+from thrift.transport.TTransport import TTransportException
from ..compat import binary_to_str, str_to_binary
import six
@@ -48,6 +49,15 @@ class TProtocolBase(object):
def __init__(self, trans):
self.trans = trans
+ @staticmethod
+ def _check_length(limit, length):
+ if length < 0:
+ raise TTransportException(TTransportException.NEGATIVE_SIZE,
+ 'Negative length: %d' % length)
+ if limit is not None and length > limit:
+ raise TTransportException(TTransportException.SIZE_LIMIT,
+ 'Length exceeded max allowed: %d' % limit)
+
def writeMessageBegin(self, name, ttype, seqid):
pass
diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c
index 337201b26..da57c8559 100644
--- a/lib/py/src/protocol/fastbinary.c
+++ b/lib/py/src/protocol/fastbinary.c
@@ -189,22 +189,19 @@ check_ssize_t_32(Py_ssize_t len) {
return false;
}
if (!CHECK_RANGE(len, 0, INT32_MAX)) {
- PyErr_SetString(PyExc_OverflowError, "string size out of range");
+ PyErr_SetString(PyExc_OverflowError, "size out of range: exceeded INT32_MAX");
return false;
}
return true;
}
-#define MAX_LIST_SIZE (10000)
-
static inline bool
-check_list_length(Py_ssize_t len) {
- // error from getting the int
- if (INT_CONV_ERROR_OCCURRED(len)) {
+check_length_limit(Py_ssize_t len, long limit) {
+ if (!check_ssize_t_32(len)) {
return false;
}
- if (!CHECK_RANGE(len, 0, MAX_LIST_SIZE)) {
- PyErr_SetString(PyExc_OverflowError, "list size out of the sanity limit (10000 items max)");
+ if (len > limit) {
+ PyErr_Format(PyExc_OverflowError, "size exceeded specified limit: %d", limit);
return false;
}
return true;
@@ -891,10 +888,10 @@ skip(DecodeBuffer* input, TType type) {
/* --- HELPER FUNCTION FOR DECODE_VAL --- */
static PyObject*
-decode_val(DecodeBuffer* input, TType type, PyObject* typeargs);
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs, long string_limit, long container_limit);
static PyObject*
-decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq) {
+decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq, long string_limit, long container_limit) {
int spec_seq_len = PyTuple_Size(spec_seq);
bool immutable = output == Py_None;
PyObject* kwargs = NULL;
@@ -954,7 +951,7 @@ decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject*
}
}
- fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs);
+ fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs, string_limit, container_limit);
if (fieldval == NULL) {
goto error;
}
@@ -991,7 +988,7 @@ decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject*
// Returns a new reference.
static PyObject*
-decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs, long string_limit, long container_limit) {
switch (type) {
case T_BOOL: {
@@ -1059,6 +1056,9 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
if (!readBytes(input, &buf, len)) {
return NULL;
}
+ if (!check_length_limit(len, string_limit)) {
+ return NULL;
+ }
if (is_utf8(typeargs))
return PyUnicode_DecodeUTF8(buf, len, 0);
@@ -1083,7 +1083,7 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
}
len = readI32(input);
- if (!check_list_length(len)) {
+ if (!check_length_limit(len, container_limit)) {
return NULL;
}
@@ -1094,7 +1094,7 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
}
for (i = 0; i < len; i++) {
- PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs);
+ PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs, string_limit, container_limit);
if (!item) {
Py_DECREF(ret);
return NULL;
@@ -1135,8 +1135,8 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
}
len = readI32(input);
- if (!check_ssize_t_32(len)) {
- return false;
+ if (!check_length_limit(len, container_limit)) {
+ return NULL;
}
ret = PyDict_New();
@@ -1147,11 +1147,11 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
for (i = 0; i < len; i++) {
PyObject* k = NULL;
PyObject* v = NULL;
- k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs);
+ k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs, string_limit, container_limit);
if (k == NULL) {
goto loop_error;
}
- v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs);
+ v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs, string_limit, container_limit);
if (v == NULL) {
goto loop_error;
}
@@ -1199,7 +1199,7 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
return NULL;
}
- return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec);
+ return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec, string_limit, container_limit);
}
case T_STOP:
@@ -1213,6 +1213,15 @@ decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
}
}
+static long as_long_or(PyObject* value, long default_value) {
+ long v = PyInt_AsLong(value);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ PyErr_Clear();
+ return default_value;
+ }
+ return v;
+}
+
/* --- TOP-LEVEL WRAPPER FOR INPUT -- */
@@ -1222,12 +1231,18 @@ decode_binary(PyObject *self, PyObject *args) {
PyObject* transport = NULL;
PyObject* typeargs = NULL;
StructTypeArgs parsedargs;
+ PyObject* string_limit_obj = NULL;
+ PyObject* container_limit_obj = NULL;
+ long string_limit = 0;
+ long container_limit = 0;
DecodeBuffer input = {0, 0};
PyObject* ret = NULL;
- if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) {
+ if (!PyArg_ParseTuple(args, "OOOOO", &output_obj, &transport, &typeargs, &string_limit_obj, &container_limit_obj)) {
return NULL;
}
+ string_limit = as_long_or(string_limit_obj, INT32_MAX);
+ container_limit = as_long_or(container_limit_obj, INT32_MAX);
if (!parse_struct_args(&parsedargs, typeargs)) {
return NULL;
@@ -1237,7 +1252,7 @@ decode_binary(PyObject *self, PyObject *args) {
return NULL;
}
- ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec);
+ ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec, string_limit, container_limit);
free_decodebuf(&input);
return ret;
}
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 8e2da8dad..f99b3b9ba 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -30,6 +30,8 @@ class TTransportException(TException):
ALREADY_OPEN = 2
TIMED_OUT = 3
END_OF_FILE = 4
+ NEGATIVE_SIZE = 5
+ SIZE_LIMIT = 6
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
diff --git a/test/features/known_failures_Linux.json b/test/features/known_failures_Linux.json
index edff41a76..7a510839a 100644
--- a/test/features/known_failures_Linux.json
+++ b/test/features/known_failures_Linux.json
@@ -27,16 +27,6 @@
"nodejs-limit_string_length_compact_buffered-ip",
"perl-limit_container_length_binary_buffered-ip",
"perl-limit_string_length_binary_buffered-ip",
- "py-limit_container_length_accel-binary_buffered-ip",
- "py-limit_container_length_binary_buffered-ip",
- "py-limit_container_length_compact_buffered-ip",
- "py-limit_string_length_accel-binary_buffered-ip",
- "py-limit_string_length_binary_buffered-ip",
- "py-limit_string_length_compact_buffered-ip",
- "py3-limit_container_length_binary_buffered-ip",
- "py3-limit_container_length_compact_buffered-ip",
- "py3-limit_string_length_binary_buffered-ip",
- "py3-limit_string_length_compact_buffered-ip",
"rb-limit_container_length_accel-binary_buffered-ip",
"rb-limit_container_length_binary_buffered-ip",
"rb-limit_container_length_compact_buffered-ip",
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index 4fa889460..f12a9fe76 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -18,7 +18,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from __future__ import division, print_function
+from __future__ import division
import glob
import logging
import os
@@ -179,9 +179,6 @@ class TestHandler(object):
def main(options):
- # Print TServer log to stdout so that the test-runner can redirect it to log files
- logging.basicConfig(level=logging.DEBUG)
-
# set up the protocol factory from the --protocol option
prot_factories = {
'binary': TBinaryProtocol.TBinaryProtocolFactory,
@@ -193,6 +190,12 @@ def main(options):
if pfactory_cls is None:
raise AssertionError('Unknown --protocol option: %s' % options.proto)
pfactory = pfactory_cls()
+ try:
+ pfactory.string_length_limit = options.string_limit
+ pfactory.container_length_limit = options.container_limit
+ except:
+ # Ignore errors for protocols that do not support length limits
+ pass
# get the server type (TSimpleServer, TNonblockingServer, etc...)
if len(args) > 1:
@@ -287,9 +290,14 @@ if __name__ == '__main__':
help="protocol to use, one of: accel, binary, compact, json")
parser.add_option('--transport', dest="trans", type="string",
help="transport to use, one of: buffered, framed")
+ parser.add_option('--container-limit', dest='container_limit', type='int', default=None)
+ parser.add_option('--string-limit', dest='string_limit', type='int', default=None)
parser.set_defaults(port=9090, verbose=1, proto='binary')
options, args = parser.parse_args()
+ # Print TServer log to stdout so that the test-runner can redirect it to log files
+ logging.basicConfig(level=options.verbose)
+
sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))
if options.libpydir:
sys.path.insert(0, glob.glob(options.libpydir)[0])