summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2021-03-21 08:09:43 +0200
committerGitHub <noreply@github.com>2021-03-21 08:09:43 +0200
commit980f8a0663fa00d3d0e0ad8cba406eb4e29902c9 (patch)
treeca5e73b74681009f1f49779b1d99703ead04cbcd
parent9e47444aa66ae055c3ef5a01d579d2eb52606f20 (diff)
parentd7af05a849b3b81922ec3da988494a70a875ca91 (diff)
downloadnumpy-980f8a0663fa00d3d0e0ad8cba406eb4e29902c9.tar.gz
Merge pull request #18593 from seberg/stop-lying-about-binop-inputs
MAINT: Do not claim input to binops is `self` (array object)
-rw-r--r--numpy/core/src/multiarray/arrayobject.c23
-rw-r--r--numpy/core/src/multiarray/calculation.c5
-rw-r--r--numpy/core/src/multiarray/number.c68
-rw-r--r--numpy/core/src/multiarray/number.h2
-rw-r--r--numpy/core/src/multiarray/temp_elide.c13
-rw-r--r--numpy/core/src/multiarray/temp_elide.h2
-rw-r--r--numpy/core/tests/test_multiarray.py9
7 files changed, 68 insertions, 54 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 1326140d5..e7fbb88cd 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -1356,11 +1356,13 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
switch (cmp_op) {
case Py_LT:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
- result = PyArray_GenericBinaryFunction(self, other, n_ops.less);
+ result = PyArray_GenericBinaryFunction(
+ (PyObject *)self, other, n_ops.less);
break;
case Py_LE:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
- result = PyArray_GenericBinaryFunction(self, other, n_ops.less_equal);
+ result = PyArray_GenericBinaryFunction(
+ (PyObject *)self, other, n_ops.less_equal);
break;
case Py_EQ:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
@@ -1410,9 +1412,8 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return result;
}
- result = PyArray_GenericBinaryFunction(self,
- (PyObject *)other,
- n_ops.equal);
+ result = PyArray_GenericBinaryFunction(
+ (PyObject *)self, (PyObject *)other, n_ops.equal);
break;
case Py_NE:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
@@ -1462,18 +1463,18 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return result;
}
- result = PyArray_GenericBinaryFunction(self, (PyObject *)other,
- n_ops.not_equal);
+ result = PyArray_GenericBinaryFunction(
+ (PyObject *)self, (PyObject *)other, n_ops.not_equal);
break;
case Py_GT:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
- result = PyArray_GenericBinaryFunction(self, other,
- n_ops.greater);
+ result = PyArray_GenericBinaryFunction(
+ (PyObject *)self, other, n_ops.greater);
break;
case Py_GE:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
- result = PyArray_GenericBinaryFunction(self, other,
- n_ops.greater_equal);
+ result = PyArray_GenericBinaryFunction(
+ (PyObject *)self, other, n_ops.greater_equal);
break;
default:
Py_INCREF(Py_NotImplemented);
diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c
index 43d88271b..7308c6b71 100644
--- a/numpy/core/src/multiarray/calculation.c
+++ b/numpy/core/src/multiarray/calculation.c
@@ -423,7 +423,8 @@ __New_PyArray_Std(PyArrayObject *self, int axis, int rtype, PyArrayObject *out,
return NULL;
}
arr2 = (PyArrayObject *)PyArray_EnsureAnyArray(
- PyArray_GenericBinaryFunction(arr1, obj3, n_ops.multiply));
+ PyArray_GenericBinaryFunction((PyObject *)arr1, obj3,
+ n_ops.multiply));
Py_DECREF(arr1);
Py_DECREF(obj3);
if (arr2 == NULL) {
@@ -1211,7 +1212,7 @@ PyArray_Conjugate(PyArrayObject *self, PyArrayObject *out)
n_ops.conjugate);
}
else {
- return PyArray_GenericBinaryFunction(self,
+ return PyArray_GenericBinaryFunction((PyObject *)self,
(PyObject *)out,
n_ops.conjugate);
}
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 78f21db4f..7e9b93782 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -262,7 +262,7 @@ PyArray_GenericAccumulateFunction(PyArrayObject *m1, PyObject *op, int axis,
NPY_NO_EXPORT PyObject *
-PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op)
+PyArray_GenericBinaryFunction(PyObject *m1, PyObject *m2, PyObject *op)
{
return PyObject_CallFunctionObjArgs(op, m1, m2, NULL);
}
@@ -287,7 +287,7 @@ PyArray_GenericInplaceUnaryFunction(PyArrayObject *m1, PyObject *op)
}
static PyObject *
-array_add(PyArrayObject *m1, PyObject *m2)
+array_add(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -299,7 +299,7 @@ array_add(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_subtract(PyArrayObject *m1, PyObject *m2)
+array_subtract(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -311,7 +311,7 @@ array_subtract(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_multiply(PyArrayObject *m1, PyObject *m2)
+array_multiply(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -323,14 +323,14 @@ array_multiply(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_remainder(PyArrayObject *m1, PyObject *m2)
+array_remainder(PyObject *m1, PyObject *m2)
{
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_remainder, array_remainder);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder);
}
static PyObject *
-array_divmod(PyArrayObject *m1, PyObject *m2)
+array_divmod(PyObject *m1, PyObject *m2)
{
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_divmod, array_divmod);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.divmod);
@@ -338,7 +338,7 @@ array_divmod(PyArrayObject *m1, PyObject *m2)
/* Need this to be version dependent on account of the slot check */
static PyObject *
-array_matrix_multiply(PyArrayObject *m1, PyObject *m2)
+array_matrix_multiply(PyObject *m1, PyObject *m2)
{
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_matrix_multiply, array_matrix_multiply);
return PyArray_GenericBinaryFunction(m1, m2, n_ops.matmul);
@@ -442,15 +442,16 @@ is_scalar_with_conversion(PyObject *o2, double* out_exponent)
* the result is in value (can be NULL if an error occurred)
*/
static int
-fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace,
+fast_scalar_power(PyObject *o1, PyObject *o2, int inplace,
PyObject **value)
{
double exponent;
NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */
- if (PyArray_Check(a1) &&
- !PyArray_ISOBJECT(a1) &&
+ if (PyArray_Check(o1) &&
+ !PyArray_ISOBJECT((PyArrayObject *)o1) &&
((kind=is_scalar_with_conversion(o2, &exponent))>0)) {
+ PyArrayObject *a1 = (PyArrayObject *)o1;
PyObject *fastop = NULL;
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
if (exponent == 1.0) {
@@ -514,7 +515,7 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace,
}
static PyObject *
-array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo)
+array_power(PyObject *a1, PyObject *o2, PyObject *modulo)
{
PyObject *value = NULL;
@@ -605,7 +606,7 @@ array_invert(PyArrayObject *m1)
}
static PyObject *
-array_left_shift(PyArrayObject *m1, PyObject *m2)
+array_left_shift(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -617,7 +618,7 @@ array_left_shift(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_right_shift(PyArrayObject *m1, PyObject *m2)
+array_right_shift(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -629,7 +630,7 @@ array_right_shift(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_bitwise_and(PyArrayObject *m1, PyObject *m2)
+array_bitwise_and(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -641,7 +642,7 @@ array_bitwise_and(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_bitwise_or(PyArrayObject *m1, PyObject *m2)
+array_bitwise_or(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -653,7 +654,7 @@ array_bitwise_or(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_bitwise_xor(PyArrayObject *m1, PyObject *m2)
+array_bitwise_xor(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -704,7 +705,7 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
INPLACE_GIVE_UP_IF_NEEDED(
a1, o2, nb_inplace_power, array_inplace_power);
- if (fast_scalar_power(a1, o2, 1, &value) != 0) {
+ if (fast_scalar_power((PyObject *)a1, o2, 1, &value) != 0) {
value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
}
return value;
@@ -751,7 +752,7 @@ array_inplace_bitwise_xor(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_floor_divide(PyArrayObject *m1, PyObject *m2)
+array_floor_divide(PyObject *m1, PyObject *m2)
{
PyObject *res;
@@ -763,13 +764,14 @@ array_floor_divide(PyArrayObject *m1, PyObject *m2)
}
static PyObject *
-array_true_divide(PyArrayObject *m1, PyObject *m2)
+array_true_divide(PyObject *m1, PyObject *m2)
{
PyObject *res;
+ PyArrayObject *a1 = (PyArrayObject *)m1;
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_true_divide, array_true_divide);
if (PyArray_CheckExact(m1) &&
- (PyArray_ISFLOAT(m1) || PyArray_ISCOMPLEX(m1)) &&
+ (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) &&
try_binary_elide(m1, m2, &array_inplace_true_divide, &res, 0)) {
return res;
}
@@ -900,22 +902,22 @@ array_index(PyArrayObject *v)
NPY_NO_EXPORT PyNumberMethods array_as_number = {
- .nb_add = (binaryfunc)array_add,
- .nb_subtract = (binaryfunc)array_subtract,
- .nb_multiply = (binaryfunc)array_multiply,
- .nb_remainder = (binaryfunc)array_remainder,
- .nb_divmod = (binaryfunc)array_divmod,
+ .nb_add = array_add,
+ .nb_subtract = array_subtract,
+ .nb_multiply = array_multiply,
+ .nb_remainder = array_remainder,
+ .nb_divmod = array_divmod,
.nb_power = (ternaryfunc)array_power,
.nb_negative = (unaryfunc)array_negative,
.nb_positive = (unaryfunc)array_positive,
.nb_absolute = (unaryfunc)array_absolute,
.nb_bool = (inquiry)_array_nonzero,
.nb_invert = (unaryfunc)array_invert,
- .nb_lshift = (binaryfunc)array_left_shift,
- .nb_rshift = (binaryfunc)array_right_shift,
- .nb_and = (binaryfunc)array_bitwise_and,
- .nb_xor = (binaryfunc)array_bitwise_xor,
- .nb_or = (binaryfunc)array_bitwise_or,
+ .nb_lshift = array_left_shift,
+ .nb_rshift = array_right_shift,
+ .nb_and = array_bitwise_and,
+ .nb_xor = array_bitwise_xor,
+ .nb_or = array_bitwise_or,
.nb_int = (unaryfunc)array_int,
.nb_float = (unaryfunc)array_float,
@@ -932,11 +934,11 @@ NPY_NO_EXPORT PyNumberMethods array_as_number = {
.nb_inplace_xor = (binaryfunc)array_inplace_bitwise_xor,
.nb_inplace_or = (binaryfunc)array_inplace_bitwise_or,
- .nb_floor_divide = (binaryfunc)array_floor_divide,
- .nb_true_divide = (binaryfunc)array_true_divide,
+ .nb_floor_divide = array_floor_divide,
+ .nb_true_divide = array_true_divide,
.nb_inplace_floor_divide = (binaryfunc)array_inplace_floor_divide,
.nb_inplace_true_divide = (binaryfunc)array_inplace_true_divide,
- .nb_matrix_multiply = (binaryfunc)array_matrix_multiply,
+ .nb_matrix_multiply = array_matrix_multiply,
.nb_inplace_matrix_multiply = (binaryfunc)array_inplace_matrix_multiply,
};
diff --git a/numpy/core/src/multiarray/number.h b/numpy/core/src/multiarray/number.h
index 643241b3d..4f426f964 100644
--- a/numpy/core/src/multiarray/number.h
+++ b/numpy/core/src/multiarray/number.h
@@ -56,7 +56,7 @@ NPY_NO_EXPORT PyObject *
_PyArray_GetNumericOps(void);
NPY_NO_EXPORT PyObject *
-PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op);
+PyArray_GenericBinaryFunction(PyObject *m1, PyObject *m2, PyObject *op);
NPY_NO_EXPORT PyObject *
PyArray_GenericUnaryFunction(PyArrayObject *m1, PyObject *op);
diff --git a/numpy/core/src/multiarray/temp_elide.c b/numpy/core/src/multiarray/temp_elide.c
index b19dee418..2b4621744 100644
--- a/numpy/core/src/multiarray/temp_elide.c
+++ b/numpy/core/src/multiarray/temp_elide.c
@@ -274,13 +274,14 @@ check_callers(int * cannot)
* "cannot" is set to true if it cannot be done even with swapped arguments
*/
static int
-can_elide_temp(PyArrayObject * alhs, PyObject * orhs, int * cannot)
+can_elide_temp(PyObject *olhs, PyObject *orhs, int *cannot)
{
/*
* to be a candidate the array needs to have reference count 1, be an exact
* array of a basic type, own its data and size larger than threshold
*/
- if (Py_REFCNT(alhs) != 1 || !PyArray_CheckExact(alhs) ||
+ PyArrayObject *alhs = (PyArrayObject *)olhs;
+ if (Py_REFCNT(olhs) != 1 || !PyArray_CheckExact(olhs) ||
!PyArray_ISNUMBER(alhs) ||
!PyArray_CHKFLAGS(alhs, NPY_ARRAY_OWNDATA) ||
!PyArray_ISWRITEABLE(alhs) ||
@@ -328,22 +329,22 @@ can_elide_temp(PyArrayObject * alhs, PyObject * orhs, int * cannot)
* try eliding a binary op, if commutative is true also try swapped arguments
*/
NPY_NO_EXPORT int
-try_binary_elide(PyArrayObject * m1, PyObject * m2,
+try_binary_elide(PyObject * m1, PyObject * m2,
PyObject * (inplace_op)(PyArrayObject * m1, PyObject * m2),
PyObject ** res, int commutative)
{
/* set when no elision can be done independent of argument order */
int cannot = 0;
if (can_elide_temp(m1, m2, &cannot)) {
- *res = inplace_op(m1, m2);
+ *res = inplace_op((PyArrayObject *)m1, m2);
#if NPY_ELIDE_DEBUG != 0
puts("elided temporary in binary op");
#endif
return 1;
}
else if (commutative && !cannot) {
- if (can_elide_temp((PyArrayObject *)m2, (PyObject *)m1, &cannot)) {
- *res = inplace_op((PyArrayObject *)m2, (PyObject *)m1);
+ if (can_elide_temp(m2, m1, &cannot)) {
+ *res = inplace_op((PyArrayObject *)m2, m1);
#if NPY_ELIDE_DEBUG != 0
puts("elided temporary in commutative binary op");
#endif
diff --git a/numpy/core/src/multiarray/temp_elide.h b/numpy/core/src/multiarray/temp_elide.h
index d073adf28..206bb0253 100644
--- a/numpy/core/src/multiarray/temp_elide.h
+++ b/numpy/core/src/multiarray/temp_elide.h
@@ -8,7 +8,7 @@ NPY_NO_EXPORT int
can_elide_temp_unary(PyArrayObject * m1);
NPY_NO_EXPORT int
-try_binary_elide(PyArrayObject * m1, PyObject * m2,
+try_binary_elide(PyObject * m1, PyObject * m2,
PyObject * (inplace_op)(PyArrayObject * m1, PyObject * m2),
PyObject ** res, int commutative);
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 269e144d9..7656b4d0a 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -3417,6 +3417,15 @@ class TestMethods:
assert_raises(TypeError, lambda: a.conj())
assert_raises(TypeError, lambda: a.conjugate())
+ def test_conjugate_out(self):
+ # Minimal test for the out argument being passed on correctly
+ # NOTE: The ability to pass `out` is currently undocumented!
+ a = np.array([1-1j, 1+1j, 23+23.0j])
+ out = np.empty_like(a)
+ res = a.conjugate(out)
+ assert res is out
+ assert_array_equal(out, a.conjugate())
+
def test__complex__(self):
dtypes = ['i1', 'i2', 'i4', 'i8',
'u1', 'u2', 'u4', 'u8',