summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAllan Haldane <allan.haldane@gmail.com>2018-02-08 00:03:41 -0500
committerCharles Harris <charlesr.harris@gmail.com>2018-02-16 13:30:42 -0700
commit6a05feae16cba75250dc22d03c72f541d7c65c26 (patch)
treeb29fefca4c11613a83e7ccde528ec4cb6e5c9b7f
parent7311b961a6827abdee8179cf40f3ab4a2b682408 (diff)
downloadnumpy-6a05feae16cba75250dc22d03c72f541d7c65c26.tar.gz
BUG: infinite recursion in str of 0d subclasses
Fixes #10360
-rw-r--r--numpy/core/arrayprint.py20
-rw-r--r--numpy/core/tests/test_arrayprint.py49
2 files changed, 63 insertions, 6 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index 987589dbe..381c1074d 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -435,14 +435,17 @@ def _recursive_guard(fillvalue='...'):
# gracefully handle recursive calls, when object arrays contain themselves
@_recursive_guard()
def _array2string(a, options, separator=' ', prefix=""):
- # The formatter __init__s cannot deal with subclasses yet
- data = asarray(a)
+ # The formatter __init__s in _get_format_function cannot deal with
+ # subclasses yet, and we also need to avoid recursion issues in
+ # _formatArray with subclasses which return 0d arrays in place of scalars
+ a = asarray(a)
if a.size > options['threshold']:
summary_insert = "..."
- data = _leading_trailing(data, options['edgeitems'])
+ data = _leading_trailing(a, options['edgeitems'])
else:
summary_insert = ""
+ data = a
# find the right formatting function for the array
format_function = _get_format_function(data, **options)
@@ -468,7 +471,7 @@ def array2string(a, max_line_width=None, precision=None,
Parameters
----------
- a : ndarray
+ a : array_like
Input array.
max_line_width : int, optional
The maximum number of columns the string should span. Newline
@@ -730,7 +733,7 @@ def _formatArray(a, format_function, line_width, next_line_prefix,
if show_summary:
if legacy == '1.13':
- # trailing space, fixed number of newlines, and fixed separator
+ # trailing space, fixed nbr of newlines, and fixed separator
s += hanging_indent + summary_insert + ", \n"
else:
s += hanging_indent + summary_insert + line_sep
@@ -1380,6 +1383,8 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
return arr_str + spacer + dtype_str
+_guarded_str = _recursive_guard()(str)
+
def array_str(a, max_line_width=None, precision=None, suppress_small=None):
"""
Return a string representation of the data in an array.
@@ -1422,7 +1427,10 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None):
# so floats are not truncated by `precision`, and strings are not wrapped
# in quotes. So we return the str of the scalar value.
if a.shape == ():
- return str(a[()])
+ # obtain a scalar and call str on it, avoiding problems for subclasses
+ # for which indexing with () returns a 0d instead of a scalar by using
+ # ndarray's getindex. Also guard against recursive 0d object arrays.
+ return _guarded_str(np.ndarray.__getitem__(a, ()))
return array2string(a, max_line_width, precision, suppress_small, ' ', "")
diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py
index 3b6bc7b0f..f70b6a333 100644
--- a/numpy/core/tests/test_arrayprint.py
+++ b/numpy/core/tests/test_arrayprint.py
@@ -34,6 +34,55 @@ class TestArrayRepr(object):
" [(1,), (1,)]], dtype=[('a', '<i4')])"
)
+ def test_0d_object_subclass(self):
+ # make sure that subclasses which return 0ds instead
+ # of scalars don't cause infinite recursion in str
+ class sub(np.ndarray):
+ def __new__(cls, inp):
+ obj = np.asarray(inp).view(cls)
+ return obj
+
+ def __getitem__(self, ind):
+ ret = super(sub, self).__getitem__(ind)
+ return sub(ret)
+
+ x = sub(1)
+ assert_equal(repr(x), 'sub(1)')
+ assert_equal(str(x), '1')
+
+ x = sub([1, 1])
+ assert_equal(repr(x), 'sub([1, 1])')
+ assert_equal(str(x), '[1 1]')
+
+ # check it works properly with object arrays too
+ x = sub(None)
+ assert_equal(repr(x), 'sub(None, dtype=object)')
+ assert_equal(str(x), 'None')
+
+ # plus recursive object arrays (even depth > 1)
+ y = sub(None)
+ x[()] = y
+ y[()] = x
+ assert_equal(repr(x),
+ 'sub(sub(sub(..., dtype=object), dtype=object), dtype=object)')
+ assert_equal(str(x), '...')
+
+ # nested 0d-subclass-object
+ x = sub(None)
+ x[()] = sub(None)
+ assert_equal(repr(x), 'sub(sub(None, dtype=object), dtype=object)')
+ assert_equal(str(x), 'None')
+
+ # test that object + subclass is OK:
+ x = sub([None, None])
+ assert_equal(repr(x), 'sub([None, None], dtype=object)')
+ assert_equal(str(x), '[None None]')
+
+ x = sub([None, sub([None, None])])
+ assert_equal(repr(x),
+ 'sub([None, sub([None, None], dtype=object)], dtype=object)')
+ assert_equal(str(x), '[None sub([None, None], dtype=object)]')
+
def test_self_containing(self):
arr0d = np.array(None)
arr0d[()] = arr0d