diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-10-04 09:05:49 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-04 09:05:49 -0600 |
commit | c386502c5d3774346dcf420528789df98066c7d6 (patch) | |
tree | 42389e87428775356710739a216426be5b3cace6 | |
parent | 86953fd4e123bba98c1aa7ba34b2fa0634fcdcce (diff) | |
parent | 8196c2a46fb621580a53ad5f7b2bd08cd154e870 (diff) | |
download | numpy-c386502c5d3774346dcf420528789df98066c7d6.tar.gz |
Merge pull request #20018 from WarrenWeckesser/show-shapes
ENH: core: More informative error message for broadcast(*args)
-rw-r--r-- | numpy/core/src/multiarray/iterators.c | 36 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 6 |
2 files changed, 38 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c index 36bfaa7cf..f959162fd 100644 --- a/numpy/core/src/multiarray/iterators.c +++ b/numpy/core/src/multiarray/iterators.c @@ -1124,6 +1124,35 @@ NPY_NO_EXPORT PyTypeObject PyArrayIter_Type = { /** END of Array Iterator **/ + +static int +set_shape_mismatch_exception(PyArrayMultiIterObject *mit, int i1, int i2) +{ + PyObject *shape1, *shape2, *msg; + + shape1 = PyObject_GetAttrString((PyObject *) mit->iters[i1]->ao, "shape"); + if (shape1 == NULL) { + return -1; + } + shape2 = PyObject_GetAttrString((PyObject *) mit->iters[i2]->ao, "shape"); + if (shape2 == NULL) { + Py_DECREF(shape1); + return -1; + } + msg = PyUnicode_FromFormat("shape mismatch: objects cannot be broadcast " + "to a single shape. Mismatch is between arg %d " + "with shape %S and arg %d with shape %S.", + i1, shape1, i2, shape2); + Py_DECREF(shape1); + Py_DECREF(shape2); + if (msg == NULL) { + return -1; + } + PyErr_SetObject(PyExc_ValueError, msg); + Py_DECREF(msg); + return 0; +} + /* Adjust dimensionality and strides for index object iterators --- i.e. broadcast */ @@ -1132,6 +1161,7 @@ NPY_NO_EXPORT int PyArray_Broadcast(PyArrayMultiIterObject *mit) { int i, nd, k, j; + int src_iter = -1; /* Initializing avoids a compiler warning. */ npy_intp tmp; PyArrayIterObject *it; @@ -1155,12 +1185,10 @@ PyArray_Broadcast(PyArrayMultiIterObject *mit) } if (mit->dimensions[i] == 1) { mit->dimensions[i] = tmp; + src_iter = j; } else if (mit->dimensions[i] != tmp) { - PyErr_SetString(PyExc_ValueError, - "shape mismatch: objects" \ - " cannot be broadcast" \ - " to a single shape"); + set_shape_mismatch_exception(mit, src_iter, j); return -1; } } diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 4510333a1..e36f76c53 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -3511,6 +3511,12 @@ class TestBroadcast: assert_raises(ValueError, np.broadcast, 1, **{'x': 1}) + def test_shape_mismatch_error_message(self): + with pytest.raises(ValueError, match=r"arg 0 with shape \(1, 3\) and " + r"arg 2 with shape \(2,\)"): + np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7]) + + class TestKeepdims: class sub_array(np.ndarray): |