summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-10-04 09:05:49 -0600
committerGitHub <noreply@github.com>2021-10-04 09:05:49 -0600
commitc386502c5d3774346dcf420528789df98066c7d6 (patch)
tree42389e87428775356710739a216426be5b3cace6
parent86953fd4e123bba98c1aa7ba34b2fa0634fcdcce (diff)
parent8196c2a46fb621580a53ad5f7b2bd08cd154e870 (diff)
downloadnumpy-c386502c5d3774346dcf420528789df98066c7d6.tar.gz
Merge pull request #20018 from WarrenWeckesser/show-shapes
ENH: core: More informative error message for broadcast(*args)
-rw-r--r--numpy/core/src/multiarray/iterators.c36
-rw-r--r--numpy/core/tests/test_numeric.py6
2 files changed, 38 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 36bfaa7cf..f959162fd 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1124,6 +1124,35 @@ NPY_NO_EXPORT PyTypeObject PyArrayIter_Type = {
/** END of Array Iterator **/
+
+static int
+set_shape_mismatch_exception(PyArrayMultiIterObject *mit, int i1, int i2)
+{
+ PyObject *shape1, *shape2, *msg;
+
+ shape1 = PyObject_GetAttrString((PyObject *) mit->iters[i1]->ao, "shape");
+ if (shape1 == NULL) {
+ return -1;
+ }
+ shape2 = PyObject_GetAttrString((PyObject *) mit->iters[i2]->ao, "shape");
+ if (shape2 == NULL) {
+ Py_DECREF(shape1);
+ return -1;
+ }
+ msg = PyUnicode_FromFormat("shape mismatch: objects cannot be broadcast "
+ "to a single shape. Mismatch is between arg %d "
+ "with shape %S and arg %d with shape %S.",
+ i1, shape1, i2, shape2);
+ Py_DECREF(shape1);
+ Py_DECREF(shape2);
+ if (msg == NULL) {
+ return -1;
+ }
+ PyErr_SetObject(PyExc_ValueError, msg);
+ Py_DECREF(msg);
+ return 0;
+}
+
/* Adjust dimensionality and strides for index object iterators
--- i.e. broadcast
*/
@@ -1132,6 +1161,7 @@ NPY_NO_EXPORT int
PyArray_Broadcast(PyArrayMultiIterObject *mit)
{
int i, nd, k, j;
+ int src_iter = -1; /* Initializing avoids a compiler warning. */
npy_intp tmp;
PyArrayIterObject *it;
@@ -1155,12 +1185,10 @@ PyArray_Broadcast(PyArrayMultiIterObject *mit)
}
if (mit->dimensions[i] == 1) {
mit->dimensions[i] = tmp;
+ src_iter = j;
}
else if (mit->dimensions[i] != tmp) {
- PyErr_SetString(PyExc_ValueError,
- "shape mismatch: objects" \
- " cannot be broadcast" \
- " to a single shape");
+ set_shape_mismatch_exception(mit, src_iter, j);
return -1;
}
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 4510333a1..e36f76c53 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -3511,6 +3511,12 @@ class TestBroadcast:
assert_raises(ValueError, np.broadcast, 1, **{'x': 1})
+ def test_shape_mismatch_error_message(self):
+ with pytest.raises(ValueError, match=r"arg 0 with shape \(1, 3\) and "
+ r"arg 2 with shape \(2,\)"):
+ np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7])
+
+
class TestKeepdims:
class sub_array(np.ndarray):