summary | refs | log | tree | commit | diff
path: root/numpy
diff options
context:
space:
mode:
author Matti Picus <matti.picus@gmail.com> 2023-02-20 21:19:43 +0200
committer GitHub <noreply@github.com> 2023-02-20 21:19:43 +0200
commit b657a75ca651368b2d29c221287046333c3f7580 (patch)
tree 0a2a757c1368876c11591a887b1c1041399bc182 /numpy
parent f6eaca8c63a06b588e73d7d189604f978e6dffe6 (diff)
parent 3d37f944826bf56d743626d1691937e7cb961a49 (diff)
download numpy-b657a75ca651368b2d29c221287046333c3f7580.tar.gz
Merge pull request #23248 from seberg/fast-take-no-refcnt
ENH: Avoid use of item XINCREF and DECREF in fasttake
Diffstat (limited to 'numpy')
-rw-r--r-- numpy/core/src/multiarray/item_selection.c 81
1 files changed, 55 insertions, 26 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 44c553133..15433b587 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -17,6 +17,7 @@
#include "multiarraymodule.h"
#include "common.h"
+#include "dtype_transfer.h"
#include "arrayobject.h"
#include "ctors.h"
#include "lowlevel_strided_loops.h"
@@ -39,7 +40,26 @@ npy_fasttake_impl(
PyArray_Descr *dtype, int axis)
{
NPY_BEGIN_THREADS_DEF;
- NPY_BEGIN_THREADS_DESCR(dtype);
+
+ NPY_cast_info cast_info;
+ NPY_ARRAYMETHOD_FLAGS flags;
+ NPY_cast_info_init(&cast_info);
+
+ if (!needs_refcounting) {
+ /* if "refcounting" is not needed memcpy is safe for a simple copy */
+ NPY_BEGIN_THREADS;
+ }
+ else {
+ if (PyArray_GetDTypeTransferFunction(
+ 1, itemsize, itemsize, dtype, dtype, 0,
+ &cast_info, &flags) < 0) {
+ return -1;
+ }
+ if (!(flags & NPY_METH_REQUIRES_PYAPI)) {
+ NPY_BEGIN_THREADS;
+ }
+ }
+
switch (clipmode) {
case NPY_RAISE:
for (npy_intp i = 0; i < n; i++) {
@@ -47,22 +67,23 @@ npy_fasttake_impl(
npy_intp tmp = indices[j];
if (check_and_adjust_index(&tmp, max_item, axis,
_save) < 0) {
- return -1;
+ goto fail;
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
- for (npy_intp k = 0; k < nelem; k++) {
- PyArray_Item_INCREF(tmp_src, dtype);
- PyArray_Item_XDECREF(dest, dtype);
- memmove(dest, tmp_src, itemsize);
- dest += itemsize;
- tmp_src += itemsize;
+ char *data[2] = {tmp_src, dest};
+ npy_intp strides[2] = {itemsize, itemsize};
+ if (cast_info.func(
+ &cast_info.context, data, &nelem, strides,
+ cast_info.auxdata) < 0) {
+ NPY_END_THREADS;
+ goto fail;
}
}
else {
- memmove(dest, tmp_src, chunk);
- dest += chunk;
+ memcpy(dest, tmp_src, chunk);
}
+ dest += chunk;
}
src += chunk*max_item;
}
@@ -83,18 +104,19 @@ npy_fasttake_impl(
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
- for (npy_intp k = 0; k < nelem; k++) {
- PyArray_Item_INCREF(tmp_src, dtype);
- PyArray_Item_XDECREF(dest, dtype);
- memmove(dest, tmp_src, itemsize);
- dest += itemsize;
- tmp_src += itemsize;
+ char *data[2] = {tmp_src, dest};
+ npy_intp strides[2] = {itemsize, itemsize};
+ if (cast_info.func(
+ &cast_info.context, data, &nelem, strides,
+ cast_info.auxdata) < 0) {
+ NPY_END_THREADS;
+ goto fail;
}
}
else {
- memmove(dest, tmp_src, chunk);
- dest += chunk;
+ memcpy(dest, tmp_src, chunk);
}
+ dest += chunk;
}
src += chunk*max_item;
}
@@ -111,18 +133,19 @@ npy_fasttake_impl(
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
- for (npy_intp k = 0; k < nelem; k++) {
- PyArray_Item_INCREF(tmp_src, dtype);
- PyArray_Item_XDECREF(dest, dtype);
- memmove(dest, tmp_src, itemsize);
- dest += itemsize;
- tmp_src += itemsize;
+ char *data[2] = {tmp_src, dest};
+ npy_intp strides[2] = {itemsize, itemsize};
+ if (cast_info.func(
+ &cast_info.context, data, &nelem, strides,
+ cast_info.auxdata) < 0) {
+ NPY_END_THREADS;
+ goto fail;
}
}
else {
- memmove(dest, tmp_src, chunk);
- dest += chunk;
+ memcpy(dest, tmp_src, chunk);
}
+ dest += chunk;
}
src += chunk*max_item;
}
@@ -130,7 +153,13 @@ npy_fasttake_impl(
}
NPY_END_THREADS;
+ NPY_cast_info_xfree(&cast_info);
return 0;
+
+ fail:
+ /* NPY_END_THREADS already ensured. */
+ NPY_cast_info_xfree(&cast_info);
+ return -1;
}