diff options
author | Allan Haldane <allan.haldane@gmail.com> | 2018-02-08 00:03:41 -0500 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2018-02-16 13:30:42 -0700 |
commit | 6a05feae16cba75250dc22d03c72f541d7c65c26 (patch) | |
tree | b29fefca4c11613a83e7ccde528ec4cb6e5c9b7f | |
parent | 7311b961a6827abdee8179cf40f3ab4a2b682408 (diff) | |
download | numpy-6a05feae16cba75250dc22d03c72f541d7c65c26.tar.gz |
BUG: infinite recursion in str of 0d subclasses
Fixes #10360
-rw-r--r-- | numpy/core/arrayprint.py | 20 | ||||
-rw-r--r-- | numpy/core/tests/test_arrayprint.py | 49 |
2 files changed, 63 insertions, 6 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 987589dbe..381c1074d 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -435,14 +435,17 @@ def _recursive_guard(fillvalue='...'): # gracefully handle recursive calls, when object arrays contain themselves @_recursive_guard() def _array2string(a, options, separator=' ', prefix=""): - # The formatter __init__s cannot deal with subclasses yet - data = asarray(a) + # The formatter __init__s in _get_format_function cannot deal with + # subclasses yet, and we also need to avoid recursion issues in + # _formatArray with subclasses which return 0d arrays in place of scalars + a = asarray(a) if a.size > options['threshold']: summary_insert = "..." - data = _leading_trailing(data, options['edgeitems']) + data = _leading_trailing(a, options['edgeitems']) else: summary_insert = "" + data = a # find the right formatting function for the array format_function = _get_format_function(data, **options) @@ -468,7 +471,7 @@ def array2string(a, max_line_width=None, precision=None, Parameters ---------- - a : ndarray + a : array_like Input array. max_line_width : int, optional The maximum number of columns the string should span. Newline @@ -730,7 +733,7 @@ def _formatArray(a, format_function, line_width, next_line_prefix, if show_summary: if legacy == '1.13': - # trailing space, fixed number of newlines, and fixed separator + # trailing space, fixed nbr of newlines, and fixed separator s += hanging_indent + summary_insert + ", \n" else: s += hanging_indent + summary_insert + line_sep @@ -1380,6 +1383,8 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): return arr_str + spacer + dtype_str +_guarded_str = _recursive_guard()(str) + def array_str(a, max_line_width=None, precision=None, suppress_small=None): """ Return a string representation of the data in an array. @@ -1422,7 +1427,10 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None): # so floats are not truncated by `precision`, and strings are not wrapped # in quotes. So we return the str of the scalar value. if a.shape == (): - return str(a[()]) + # obtain a scalar and call str on it, avoiding problems for subclasses + # for which indexing with () returns a 0d instead of a scalar by using + # ndarray's getindex. Also guard against recursive 0d object arrays. + return _guarded_str(np.ndarray.__getitem__(a, ())) return array2string(a, max_line_width, precision, suppress_small, ' ', "") diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py index 3b6bc7b0f..f70b6a333 100644 --- a/numpy/core/tests/test_arrayprint.py +++ b/numpy/core/tests/test_arrayprint.py @@ -34,6 +34,55 @@ class TestArrayRepr(object): " [(1,), (1,)]], dtype=[('a', '<i4')])" ) + def test_0d_object_subclass(self): + # make sure that subclasses which return 0ds instead + # of scalars don't cause infinite recursion in str + class sub(np.ndarray): + def __new__(cls, inp): + obj = np.asarray(inp).view(cls) + return obj + + def __getitem__(self, ind): + ret = super(sub, self).__getitem__(ind) + return sub(ret) + + x = sub(1) + assert_equal(repr(x), 'sub(1)') + assert_equal(str(x), '1') + + x = sub([1, 1]) + assert_equal(repr(x), 'sub([1, 1])') + assert_equal(str(x), '[1 1]') + + # check it works properly with object arrays too + x = sub(None) + assert_equal(repr(x), 'sub(None, dtype=object)') + assert_equal(str(x), 'None') + + # plus recursive object arrays (even depth > 1) + y = sub(None) + x[()] = y + y[()] = x + assert_equal(repr(x), + 'sub(sub(sub(..., dtype=object), dtype=object), dtype=object)') + assert_equal(str(x), '...') + + # nested 0d-subclass-object + x = sub(None) + x[()] = sub(None) + assert_equal(repr(x), 'sub(sub(None, dtype=object), dtype=object)') + assert_equal(str(x), 'None') + + # test that object + subclass is OK: + x = sub([None, None]) + assert_equal(repr(x), 'sub([None, None], dtype=object)') + assert_equal(str(x), '[None None]') + + x = sub([None, sub([None, None])]) + assert_equal(repr(x), + 'sub([None, sub([None, None], dtype=object)], dtype=object)') + assert_equal(str(x), '[None sub([None, None], dtype=object)]') + def test_self_containing(self): arr0d = np.array(None) arr0d[()] = arr0d |