summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-04-25 11:07:06 -0600
committerGitHub <noreply@github.com>2018-04-25 11:07:06 -0600
commitf81f3839d879491dec2545e6576db0cb0b214fae (patch)
treef03d7e23bcfcedf70e56b324d0ba364356f2cebb
parentd55743b0fc542a6e2004173e41732635fa6f505a (diff)
parentb8d6d4d35feef83844ac2b067c3d293990bff4f5 (diff)
downloadnumpy-f81f3839d879491dec2545e6576db0cb0b214fae.tar.gz
Merge pull request #10911 from toslunar/fix-10899-einsum-float16
BUG: Fix casting between npy_half and float in einsum
-rw-r--r--numpy/core/src/multiarray/einsum.c.src6
-rw-r--r--numpy/core/tests/test_einsum.py10
2 files changed, 13 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 5dbc30aa9..470a5fff9 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -591,7 +591,7 @@ finish_after_unrolled_loop:
accum += @from@(data0[@i@]) * @from@(data1[@i@]);
/**end repeat2**/
case 0:
- *(@type@ *)dataptr[2] += @to@(accum);
+ *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + accum);
return;
}
@@ -749,7 +749,7 @@ finish_after_unrolled_loop:
accum += @from@(data1[@i@]);
/**end repeat2**/
case 0:
- *(@type@ *)dataptr[2] += @to@(value0 * accum);
+ *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + value0 * accum);
return;
}
@@ -848,7 +848,7 @@ finish_after_unrolled_loop:
accum += @from@(data0[@i@]);
/**end repeat2**/
case 0:
- *(@type@ *)dataptr[2] += @to@(accum * value1);
+ *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + accum * value1);
return;
}
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 792b9e0a2..104dd1986 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -502,6 +502,16 @@ class TestEinSum(object):
optimize=optimize),
np.full((1, 5), 5))
+ # Cases which were failing (gh-10899)
+ x = np.eye(2, dtype=dtype)
+ y = np.ones(2, dtype=dtype)
+ assert_array_equal(np.einsum("ji,i->", x, y, optimize=optimize),
+ [2.]) # contig_contig_outstride0_two
+ assert_array_equal(np.einsum("i,ij->", y, x, optimize=optimize),
+ [2.]) # stride0_contig_outstride0_two
+ assert_array_equal(np.einsum("ij,i->", x, y, optimize=optimize),
+ [2.]) # contig_stride0_outstride0_two
+
def test_einsum_sums_int8(self):
self.check_einsum_sums('i1')