diff options
author | mattip <matti.picus@gmail.com> | 2018-06-08 16:06:50 -0700 |
---|---|---|
committer | mattip <matti.picus@gmail.com> | 2018-06-11 11:31:06 -0700 |
commit | 01a0971afc00b5ab610d3cb72d1111452c663bf2 (patch) | |
tree | eb3141452eae072914f5960f3787808fe3b0e329 | |
parent | 1b920805704095fde1b8f6ad7ff81a62f5176dd6 (diff) | |
download | numpy-01a0971afc00b5ab610d3cb72d1111452c663bf2.tar.gz |
BUG: einsum needs to check overlap on an out argument
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 28 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 11 |
2 files changed, 27 insertions, 12 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 69833bee6..33184d99a 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -2499,7 +2499,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS]; int *op_axes[NPY_MAXARGS]; - npy_uint32 op_flags[NPY_MAXARGS]; + npy_uint32 iter_flags, op_flags[NPY_MAXARGS]; NpyIter *iter; sum_of_products_fn sop; @@ -2745,19 +2745,23 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, NPY_ITER_ALIGNED| NPY_ITER_ALLOCATE| NPY_ITER_NO_BROADCAST; + iter_flags = NPY_ITER_EXTERNAL_LOOP| + NPY_ITER_BUFFERED| + NPY_ITER_DELAY_BUFALLOC| + NPY_ITER_GROWINNER| + NPY_ITER_REDUCE_OK| + NPY_ITER_REFS_OK| + NPY_ITER_ZEROSIZE_OK; + if (out != NULL) { + iter_flags |= NPY_ITER_COPY_IF_OVERLAP; + } + if (dtype == NULL) { + iter_flags |= NPY_ITER_COMMON_DTYPE; + } /* Allocate the iterator */ - iter = NpyIter_AdvancedNew(nop+1, op, NPY_ITER_EXTERNAL_LOOP| - ((dtype != NULL) ? 0 : NPY_ITER_COMMON_DTYPE)| - NPY_ITER_BUFFERED| - NPY_ITER_DELAY_BUFALLOC| - NPY_ITER_GROWINNER| - NPY_ITER_REDUCE_OK| - NPY_ITER_REFS_OK| - NPY_ITER_ZEROSIZE_OK, - order, casting, - op_flags, op_dtypes, - ndim_iter, op_axes, NULL, 0); + iter = NpyIter_AdvancedNew(nop+1, op, iter_flags, order, casting, op_flags, + op_dtypes, ndim_iter, op_axes, NULL, 0); if (iter == NULL) { goto fail; diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index 63e75ff7a..647738831 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -961,3 +961,14 @@ class TestEinSumPath(object): for sp in itertools.product(['', ' '], repeat=4): # no error for any spacing np.einsum('{}...a{}->{}...a{}'.format(*sp), arr) + +def test_overlap(): + a = np.arange(9, dtype=int).reshape(3, 3) + b = np.arange(9, dtype=int).reshape(3, 3) + d = np.dot(a, b) + # sanity check + c = np.einsum('ij,jk->ik', a, b) + assert_equal(c, d) + #gh-10080, out overlaps one of the operands + c = np.einsum('ij,jk->ik', a, b, out=b) + assert_equal(c, d) |