summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormattip <matti.picus@gmail.com>2018-06-08 16:06:50 -0700
committermattip <matti.picus@gmail.com>2018-06-11 11:31:06 -0700
commit01a0971afc00b5ab610d3cb72d1111452c663bf2 (patch)
treeeb3141452eae072914f5960f3787808fe3b0e329
parent1b920805704095fde1b8f6ad7ff81a62f5176dd6 (diff)
downloadnumpy-01a0971afc00b5ab610d3cb72d1111452c663bf2.tar.gz
BUG: einsum needs to check overlap on an out argument
-rw-r--r--numpy/core/src/multiarray/einsum.c.src28
-rw-r--r--numpy/core/tests/test_einsum.py11
2 files changed, 27 insertions, 12 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 69833bee6..33184d99a 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -2499,7 +2499,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
int *op_axes[NPY_MAXARGS];
- npy_uint32 op_flags[NPY_MAXARGS];
+ npy_uint32 iter_flags, op_flags[NPY_MAXARGS];
NpyIter *iter;
sum_of_products_fn sop;
@@ -2745,19 +2745,23 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
NPY_ITER_ALIGNED|
NPY_ITER_ALLOCATE|
NPY_ITER_NO_BROADCAST;
+ iter_flags = NPY_ITER_EXTERNAL_LOOP|
+ NPY_ITER_BUFFERED|
+ NPY_ITER_DELAY_BUFALLOC|
+ NPY_ITER_GROWINNER|
+ NPY_ITER_REDUCE_OK|
+ NPY_ITER_REFS_OK|
+ NPY_ITER_ZEROSIZE_OK;
+ if (out != NULL) {
+ iter_flags |= NPY_ITER_COPY_IF_OVERLAP;
+ }
+ if (dtype == NULL) {
+ iter_flags |= NPY_ITER_COMMON_DTYPE;
+ }
/* Allocate the iterator */
- iter = NpyIter_AdvancedNew(nop+1, op, NPY_ITER_EXTERNAL_LOOP|
- ((dtype != NULL) ? 0 : NPY_ITER_COMMON_DTYPE)|
- NPY_ITER_BUFFERED|
- NPY_ITER_DELAY_BUFALLOC|
- NPY_ITER_GROWINNER|
- NPY_ITER_REDUCE_OK|
- NPY_ITER_REFS_OK|
- NPY_ITER_ZEROSIZE_OK,
- order, casting,
- op_flags, op_dtypes,
- ndim_iter, op_axes, NULL, 0);
+ iter = NpyIter_AdvancedNew(nop+1, op, iter_flags, order, casting, op_flags,
+ op_dtypes, ndim_iter, op_axes, NULL, 0);
if (iter == NULL) {
goto fail;
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 63e75ff7a..647738831 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -961,3 +961,14 @@ class TestEinSumPath(object):
for sp in itertools.product(['', ' '], repeat=4):
# no error for any spacing
np.einsum('{}...a{}->{}...a{}'.format(*sp), arr)
+
+def test_overlap():
+ a = np.arange(9, dtype=int).reshape(3, 3)
+ b = np.arange(9, dtype=int).reshape(3, 3)
+ d = np.dot(a, b)
+ # sanity check
+ c = np.einsum('ij,jk->ik', a, b)
+ assert_equal(c, d)
+ #gh-10080, out overlaps one of the operands
+ c = np.einsum('ij,jk->ik', a, b, out=b)
+ assert_equal(c, d)