summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--simplejson/_speedups.c103
-rw-r--r--simplejson/tests/test_speedups.py55
2 files changed, 123 insertions, 35 deletions
diff --git a/simplejson/_speedups.c b/simplejson/_speedups.c
index 212616a..291f56e 100644
--- a/simplejson/_speedups.c
+++ b/simplejson/_speedups.c
@@ -1440,10 +1440,12 @@ _parse_object_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_
PyObject *key = NULL;
PyObject *val = NULL;
char *encoding = JSON_ASCII_AS_STRING(s->encoding);
- int strict = PyObject_IsTrue(s->strict);
int has_pairs_hook = (s->pairs_hook != Py_None);
int did_parse = 0;
Py_ssize_t next_idx;
+ int strict = PyObject_IsTrue(s->strict);
+ if (strict < 0)
+ return NULL;
if (has_pairs_hook) {
pairs = PyList_New(0);
if (pairs == NULL)
@@ -1601,11 +1603,13 @@ _parse_object_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ss
PyObject *item;
PyObject *key = NULL;
PyObject *val = NULL;
- int strict = PyObject_IsTrue(s->strict);
int has_pairs_hook = (s->pairs_hook != Py_None);
int did_parse = 0;
Py_ssize_t next_idx;
+ int strict = PyObject_IsTrue(s->strict);
+ if (strict < 0)
+ return NULL;
if (has_pairs_hook) {
pairs = PyList_New(0);
if (pairs == NULL)
@@ -2167,6 +2171,7 @@ scan_once_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *n
Py_ssize_t length = PyString_GET_SIZE(pystr);
PyObject *rval = NULL;
int fallthrough = 0;
+ int strict;
if (idx < 0 || idx >= length) {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
return NULL;
@@ -2174,9 +2179,12 @@ scan_once_str(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_t *n
switch (str[idx]) {
case '"':
/* string */
+ strict = PyObject_IsTrue(s->strict)
+ if (strict < 0)
+ return NULL;
rval = scanstring_str(pystr, idx + 1,
JSON_ASCII_AS_STRING(s->encoding),
- PyObject_IsTrue(s->strict),
+ strict,
next_idx_ptr);
break;
case '{':
@@ -2275,6 +2283,7 @@ scan_once_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_
Py_ssize_t length = PyUnicode_GetLength(pystr);
PyObject *rval = NULL;
int fallthrough = 0;
+ int strict;
if (idx < 0 || idx >= length) {
raise_errmsg(ERR_EXPECTING_VALUE, pystr, idx);
return NULL;
@@ -2282,8 +2291,11 @@ scan_once_unicode(PyScannerObject *s, PyObject *pystr, Py_ssize_t idx, Py_ssize_
switch (PyUnicode_READ(kind, str, idx)) {
case '"':
/* string */
+ strict = PyObject_IsTrue(s->strict);
+ if (strict < 0)
+ return NULL;
rval = scanstring_unicode(pystr, idx + 1,
- PyObject_IsTrue(s->strict),
+ strict,
next_idx_ptr);
break;
case '{':
@@ -2577,6 +2589,7 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyObject *use_decimal, *namedtuple_as_object, *tuple_as_array, *iterable_as_array;
PyObject *int_as_string_bitcount, *item_sort_key, *encoding, *for_json;
PyObject *ignore_nan, *Decimal;
+ int is_true;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOOOOOOO:make_encoder", kwlist,
&markers, &defaultfn, &encoder, &indent, &key_separator, &item_separator,
@@ -2608,29 +2621,44 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
Py_INCREF(skipkeys);
s->skipkeys_bool = skipkeys;
s->skipkeys = PyObject_IsTrue(skipkeys);
+ if (s->skipkeys < 0)
+ goto bail;
Py_INCREF(key_memo);
s->key_memo = key_memo;
s->fast_encode = (PyCFunction_Check(s->encoder) && PyCFunction_GetFunction(s->encoder) == (PyCFunction)py_encode_basestring_ascii);
- s->allow_or_ignore_nan = (
- (PyObject_IsTrue(ignore_nan) ? JSON_IGNORE_NAN : 0) |
- (PyObject_IsTrue(allow_nan) ? JSON_ALLOW_NAN : 0));
+ is_true = PyObject_IsTrue(ignore_nan);
+ if (is_true < 0)
+ goto bail;
+ s->allow_or_ignore_nan = is_true ? JSON_IGNORE_NAN : 0;
+ is_true = PyObject_IsTrue(allow_nan);
+ if (is_true < 0)
+ goto bail;
+ s->allow_or_ignore_nan |= is_true ? JSON_ALLOW_NAN : 0;
s->use_decimal = PyObject_IsTrue(use_decimal);
+ if (s->use_decimal < 0)
+ goto bail;
s->namedtuple_as_object = PyObject_IsTrue(namedtuple_as_object);
+ if (s->namedtuple_as_object < 0)
+ goto bail;
s->tuple_as_array = PyObject_IsTrue(tuple_as_array);
+ if (s->tuple_as_array < 0)
+ goto bail;
s->iterable_as_array = PyObject_IsTrue(iterable_as_array);
+ if (s->iterable_as_array < 0)
+ goto bail;
if (PyInt_Check(int_as_string_bitcount) || PyLong_Check(int_as_string_bitcount)) {
- static const unsigned int long_long_bitsize = SIZEOF_LONG_LONG * 8;
- int int_as_string_bitcount_val = (int)PyLong_AsLong(int_as_string_bitcount);
- if (int_as_string_bitcount_val > 0 && int_as_string_bitcount_val < (int)long_long_bitsize) {
- s->max_long_size = PyLong_FromUnsignedLongLong(1ULL << int_as_string_bitcount_val);
- s->min_long_size = PyLong_FromLongLong(-1LL << int_as_string_bitcount_val);
+ static const unsigned long long_long_bitsize = SIZEOF_LONG_LONG * 8;
+ long int_as_string_bitcount_val = PyLong_AsLong(int_as_string_bitcount);
+ if (int_as_string_bitcount_val > 0 && int_as_string_bitcount_val < (long)long_long_bitsize) {
+ s->max_long_size = PyLong_FromUnsignedLongLong(1ULL << (int)int_as_string_bitcount_val);
+ s->min_long_size = PyLong_FromLongLong(-1LL << (int)int_as_string_bitcount_val);
if (s->min_long_size == NULL || s->max_long_size == NULL) {
goto bail;
}
}
else {
PyErr_Format(PyExc_TypeError,
- "int_as_string_bitcount (%d) must be greater than 0 and less than the number of bits of a `long long` type (%u bits)",
+ "int_as_string_bitcount (%ld) must be greater than 0 and less than the number of bits of a `long long` type (%lu bits)",
int_as_string_bitcount_val, long_long_bitsize);
goto bail;
}
@@ -2651,29 +2679,34 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
goto bail;
}
}
- else if (PyObject_IsTrue(sort_keys)) {
- static PyObject *itemgetter0 = NULL;
- if (!itemgetter0) {
- PyObject *operator = PyImport_ImportModule("operator");
- if (!operator)
- goto bail;
- itemgetter0 = PyObject_CallMethod(operator, "itemgetter", "i", 0);
- Py_DECREF(operator);
- }
- item_sort_key = itemgetter0;
- if (!item_sort_key)
- goto bail;
- }
- if (item_sort_key == Py_None) {
- Py_INCREF(Py_None);
- s->item_sort_kw = Py_None;
- }
else {
- s->item_sort_kw = PyDict_New();
- if (s->item_sort_kw == NULL)
- goto bail;
- if (PyDict_SetItemString(s->item_sort_kw, "key", item_sort_key))
+ is_true = PyObject_IsTrue(sort_keys);
+ if (is_true < 0)
goto bail;
+ if (is_true) {
+ static PyObject *itemgetter0 = NULL;
+ if (!itemgetter0) {
+ PyObject *operator = PyImport_ImportModule("operator");
+ if (!operator)
+ goto bail;
+ itemgetter0 = PyObject_CallMethod(operator, "itemgetter", "i", 0);
+ Py_DECREF(operator);
+ }
+ item_sort_key = itemgetter0;
+ if (!item_sort_key)
+ goto bail;
+ }
+ if (item_sort_key == Py_None) {
+ Py_INCREF(Py_None);
+ s->item_sort_kw = Py_None;
+ }
+ else {
+ s->item_sort_kw = PyDict_New();
+ if (s->item_sort_kw == NULL)
+ goto bail;
+ if (PyDict_SetItemString(s->item_sort_kw, "key", item_sort_key))
+ goto bail;
+ }
}
Py_INCREF(sort_keys);
s->sort_keys = sort_keys;
@@ -2682,6 +2715,8 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
Py_INCREF(Decimal);
s->Decimal = Decimal;
s->for_json = PyObject_IsTrue(for_json);
+ if (s->for_json < 0)
+ goto bail;
return (PyObject *)s;
diff --git a/simplejson/tests/test_speedups.py b/simplejson/tests/test_speedups.py
index 0a2b63b..25e76dc 100644
--- a/simplejson/tests/test_speedups.py
+++ b/simplejson/tests/test_speedups.py
@@ -2,7 +2,9 @@ import sys
import unittest
from unittest import TestCase
-from simplejson import encoder, scanner
+import simplejson
+from simplejson import encoder, decoder, scanner
+from simplejson.compat import PY3
def has_speedups():
@@ -22,12 +24,31 @@ def skip_if_speedups_missing(func):
return wrapper
+class BadBool:
+ def __bool__(self):
+ 1/0
+ __nonzero__ = __bool__
+
+
class TestDecode(TestCase):
@skip_if_speedups_missing
def test_make_scanner(self):
self.assertRaises(AttributeError, scanner.c_make_scanner, 1)
@skip_if_speedups_missing
+ def test_bad_bool_args(self):
+ with self.assertRaises(ZeroDivisionError):
+ decoder.JSONDecoder(strict=BadBool()).decode('""')
+ with self.assertRaises(ZeroDivisionError):
+ decoder.JSONDecoder(strict=BadBool()).decode('{}')
+ if not PY3:
+ with self.assertRaises(ZeroDivisionError):
+ decoder.JSONDecoder(strict=BadBool()).decode(u'""')
+ with self.assertRaises(ZeroDivisionError):
+ decoder.JSONDecoder(strict=BadBool()).decode(u'{}')
+
+class TestEncode(TestCase):
+ @skip_if_speedups_missing
def test_make_encoder(self):
self.assertRaises(
TypeError,
@@ -37,3 +58,35 @@ class TestDecode(TestCase):
"\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75"),
None
)
+
+ @skip_if_speedups_missing
+ def test_bad_bool_args(self):
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(skipkeys=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(ensure_ascii=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(check_circular=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(allow_nan=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(sort_keys=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(use_decimal=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(namedtuple_as_object=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(tuple_as_array=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(bigint_as_string=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(for_json=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(ignore_nan=BadBool()).encode({})
+ with self.assertRaises(ZeroDivisionError):
+ encoder.JSONEncoder(iterable_as_array=BadBool()).encode({})
+
+ @skip_if_speedups_missing
+ def test_int_as_string_bitcount_overflow(self):
+ with self.assertRaises((TypeError, OverflowError)):
+ encoder.JSONEncoder(int_as_string_bitcount=2**32+31).encode(0)