summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2020-12-11 08:34:50 +0200
committerGitHub <noreply@github.com>2020-12-11 08:34:50 +0200
commit9e26d1d2be7a961a16f8fa9ff7820c33b25415e2 (patch)
tree08abaed7fdb7f1ed4002b208f1938887f69255d0
parent91d9bbebc159ada6cccd0e1fcec1926e26b19c6f (diff)
parentf921f0d13bb34d82503bfa2b3bff24d095bb9385 (diff)
downloadnumpy-9e26d1d2be7a961a16f8fa9ff7820c33b25415e2.tar.gz
Merge pull request #17782 from Qiyu8/einsum-muladd
SIMD: Optimize the performance of einsum's submodule multiply by using universal intrinsics
-rw-r--r--numpy/core/src/multiarray/einsum_sumprod.c.src183
1 files changed, 82 insertions, 101 deletions
diff --git a/numpy/core/src/multiarray/einsum_sumprod.c.src b/numpy/core/src/multiarray/einsum_sumprod.c.src
index c58e74287..caba0e00a 100644
--- a/numpy/core/src/multiarray/einsum_sumprod.c.src
+++ b/numpy/core/src/multiarray/einsum_sumprod.c.src
@@ -17,7 +17,8 @@
#include "einsum_sumprod.h"
#include "einsum_debug.h"
-
+#include "simd/simd.h"
+#include "common.h"
#ifdef NPY_HAVE_SSE_INTRINSICS
#define EINSUM_USE_SSE1 1
@@ -41,6 +42,13 @@
#define EINSUM_IS_SSE_ALIGNED(x) ((((npy_intp)x)&0xf) == 0)
+// ARM/Neon don't have instructions for aligned memory access
+#ifdef NPY_HAVE_NEON
+ #define EINSUM_IS_ALIGNED(x) 0
+#else
+ #define EINSUM_IS_ALIGNED(x) npy_is_aligned(x, NPY_SIMD_WIDTH)
+#endif
+
/**********************************************/
/**begin repeat
@@ -56,6 +64,10 @@
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
* npy_float, npy_float, npy_double, npy_longdouble,
* npy_float, npy_double, npy_longdouble#
+ * #sfx = s8, s16, s32, long, s64,
+ * u8, u16, u32, ulong, u64,
+ * half, f32, f64, longdouble,
+ * f32, f64, clongdouble#
* #to = ,,,,,
* ,,,,,
* npy_float_to_half,,,,
@@ -76,6 +88,10 @@
* 0*5,
* 0,0,1,0,
* 0*3#
+ * #NPYV_CHK = 0*5,
+ * 0*5,
+ * 0, NPY_SIMD, NPY_SIMD_F64, 0,
+ * 0*3#
*/
/**begin repeat1
@@ -250,115 +266,80 @@ static void
@type@ *data0 = (@type@ *)dataptr[0];
@type@ *data1 = (@type@ *)dataptr[1];
@type@ *data_out = (@type@ *)dataptr[2];
-
-#if EINSUM_USE_SSE1 && @float32@
- __m128 a, b;
-#elif EINSUM_USE_SSE2 && @float64@
- __m128d a, b;
-#endif
-
NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_two (%d)\n",
(int)count);
-
-/* This is placed before the main loop to make small counts faster */
-finish_after_unrolled_loop:
- switch (count) {
-/**begin repeat2
- * #i = 6, 5, 4, 3, 2, 1, 0#
- */
- case @i@+1:
- data_out[@i@] = @to@(@from@(data0[@i@]) *
- @from@(data1[@i@]) +
- @from@(data_out[@i@]));
-/**end repeat2**/
- case 0:
- return;
- }
-
-#if EINSUM_USE_SSE1 && @float32@
+ // NPYV check for @type@
+#if @NPYV_CHK@
/* Use aligned instructions if possible */
- if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1) &&
- EINSUM_IS_SSE_ALIGNED(data_out)) {
- /* Unroll the loop by 8 */
- while (count >= 8) {
- count -= 8;
-
-/**begin repeat2
- * #i = 0, 4#
- */
- a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@));
- b = _mm_add_ps(a, _mm_load_ps(data_out+@i@));
- _mm_store_ps(data_out+@i@, b);
-/**end repeat2**/
- data0 += 8;
- data1 += 8;
- data_out += 8;
+ const int is_aligned = EINSUM_IS_ALIGNED(data0) && EINSUM_IS_ALIGNED(data1) &&
+ EINSUM_IS_ALIGNED(data_out);
+ const int vstep = npyv_nlanes_@sfx@;
+
+ /**begin repeat2
+ * #cond = if(is_aligned), else#
+ * #ld = loada, load#
+ * #st = storea, store#
+ */
+ @cond@ {
+ const npy_intp vstepx4 = vstep * 4;
+ for (; count >= vstepx4; count -= vstepx4, data0 += vstepx4, data1 += vstepx4, data_out += vstepx4) {
+ /**begin repeat3
+ * #i = 0, 1, 2, 3#
+ */
+ npyv_@sfx@ a@i@ = npyv_@ld@_@sfx@(data0 + vstep * @i@);
+ npyv_@sfx@ b@i@ = npyv_@ld@_@sfx@(data1 + vstep * @i@);
+ npyv_@sfx@ c@i@ = npyv_@ld@_@sfx@(data_out + vstep * @i@);
+ /**end repeat3**/
+ /**begin repeat3
+ * #i = 0, 1, 2, 3#
+ */
+ npyv_@sfx@ abc@i@ = npyv_muladd_@sfx@(a@i@, b@i@, c@i@);
+ /**end repeat3**/
+ /**begin repeat3
+ * #i = 0, 1, 2, 3#
+ */
+ npyv_@st@_@sfx@(data_out + vstep * @i@, abc@i@);
+ /**end repeat3**/
}
-
- /* Finish off the loop */
- goto finish_after_unrolled_loop;
}
-#elif EINSUM_USE_SSE2 && @float64@
- /* Use aligned instructions if possible */
- if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1) &&
- EINSUM_IS_SSE_ALIGNED(data_out)) {
- /* Unroll the loop by 8 */
- while (count >= 8) {
- count -= 8;
-
-/**begin repeat2
- * #i = 0, 2, 4, 6#
- */
- a = _mm_mul_pd(_mm_load_pd(data0+@i@), _mm_load_pd(data1+@i@));
- b = _mm_add_pd(a, _mm_load_pd(data_out+@i@));
- _mm_store_pd(data_out+@i@, b);
-/**end repeat2**/
- data0 += 8;
- data1 += 8;
- data_out += 8;
- }
-
- /* Finish off the loop */
- goto finish_after_unrolled_loop;
+ /**end repeat2**/
+ for (; count > 0; count -= vstep, data0 += vstep, data1 += vstep, data_out += vstep) {
+ npyv_@sfx@ a = npyv_load_tillz_@sfx@(data0, count);
+ npyv_@sfx@ b = npyv_load_tillz_@sfx@(data1, count);
+ npyv_@sfx@ c = npyv_load_tillz_@sfx@(data_out, count);
+ npyv_store_till_@sfx@(data_out, count, npyv_muladd_@sfx@(a, b, c));
}
-#endif
-
- /* Unroll the loop by 8 */
- while (count >= 8) {
- count -= 8;
-
-#if EINSUM_USE_SSE1 && @float32@
-/**begin repeat2
- * #i = 0, 4#
- */
- a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@));
- b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@));
- _mm_storeu_ps(data_out+@i@, b);
-/**end repeat2**/
-#elif EINSUM_USE_SSE2 && @float64@
-/**begin repeat2
- * #i = 0, 2, 4, 6#
- */
- a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), _mm_loadu_pd(data1+@i@));
- b = _mm_add_pd(a, _mm_loadu_pd(data_out+@i@));
- _mm_storeu_pd(data_out+@i@, b);
-/**end repeat2**/
+ npyv_cleanup();
#else
-/**begin repeat2
- * #i = 0, 1, 2, 3, 4, 5, 6, 7#
- */
- data_out[@i@] = @to@(@from@(data0[@i@]) *
- @from@(data1[@i@]) +
- @from@(data_out[@i@]));
-/**end repeat2**/
-#endif
- data0 += 8;
- data1 += 8;
- data_out += 8;
+#ifndef NPY_DISABLE_OPTIMIZATION
+ for (; count >= 4; count -= 4, data0 += 4, data1 += 4, data_out += 4) {
+ /**begin repeat2
+ * #i = 0, 1, 2, 3#
+ */
+ const @type@ a@i@ = @from@(data0[@i@]);
+ const @type@ b@i@ = @from@(data1[@i@]);
+ const @type@ c@i@ = @from@(data_out[@i@]);
+ /**end repeat2**/
+ /**begin repeat2
+ * #i = 0, 1, 2, 3#
+ */
+ const @type@ abc@i@ = a@i@ * b@i@ + c@i@;
+ /**end repeat2**/
+ /**begin repeat2
+ * #i = 0, 1, 2, 3#
+ */
+ data_out[@i@] = @to@(abc@i@);
+ /**end repeat2**/
}
+#endif // !NPY_DISABLE_OPTIMIZATION
+ for (; count > 0; --count, ++data0, ++data1, ++data_out) {
+ const @type@ a = @from@(*data0);
+ const @type@ b = @from@(*data1);
+ const @type@ c = @from@(*data_out);
+ *data_out = @to@(a * b + c);
+ }
+#endif // NPYV check for @type@
- /* Finish off the loop */
- goto finish_after_unrolled_loop;
}
/* Some extra specializations for the two operand case */