summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2020-11-03 18:06:37 +0200
committerGitHub <noreply@github.com>2020-11-03 18:06:37 +0200
commitd62b0ee88b20e5946fe49f0ba533b3e547e4d4f1 (patch)
tree476a0581d9a2a595e337b3bd2982a759c34f6b0e
parent4c83c0444c68b89b051f7ef8d8eb1a2276439d78 (diff)
parentd02ca96090ea2fed97b7789a855668c1ddc98294 (diff)
downloadnumpy-d62b0ee88b20e5946fe49f0ba533b3e547e4d4f1.tar.gz
Merge pull request #17295 from seberg/issue-17294
BUG,ENH: fix pickling user-scalars by allowing non-format buffer export
-rw-r--r--numpy/core/src/multiarray/buffer.c66
-rw-r--r--numpy/core/src/multiarray/scalarapi.c4
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src45
-rw-r--r--numpy/core/src/umath/_rational_tests.c.src2
-rw-r--r--numpy/core/tests/test_multiarray.py16
5 files changed, 106 insertions, 27 deletions
diff --git a/numpy/core/src/multiarray/buffer.c b/numpy/core/src/multiarray/buffer.c
index e676682de..3b3bba663 100644
--- a/numpy/core/src/multiarray/buffer.c
+++ b/numpy/core/src/multiarray/buffer.c
@@ -456,7 +456,7 @@ static PyObject *_buffer_info_cache = NULL;
/* Fill in the info structure */
static _buffer_info_t*
-_buffer_info_new(PyObject *obj, npy_bool f_contiguous)
+_buffer_info_new(PyObject *obj, int flags)
{
/*
* Note that the buffer info is cached as PyLongObjects making them appear
@@ -514,6 +514,7 @@ _buffer_info_new(PyObject *obj, npy_bool f_contiguous)
* (This is unnecessary, but has no effect in the case where
* NPY_RELAXED_STRIDES CHECKING is disabled.)
*/
+ int f_contiguous = (flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS;
if (PyArray_IS_C_CONTIGUOUS(arr) && !(
f_contiguous && PyArray_IS_F_CONTIGUOUS(arr))) {
Py_ssize_t sd = PyArray_ITEMSIZE(arr);
@@ -547,16 +548,20 @@ _buffer_info_new(PyObject *obj, npy_bool f_contiguous)
}
/* Fill in format */
- err = _buffer_format_string(descr, &fmt, obj, NULL, NULL);
- Py_DECREF(descr);
- if (err != 0) {
- goto fail;
+ if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
+ err = _buffer_format_string(descr, &fmt, obj, NULL, NULL);
+ Py_DECREF(descr);
+ if (err != 0) {
+ goto fail;
+ }
+ if (_append_char(&fmt, '\0') < 0) {
+ goto fail;
+ }
+ info->format = fmt.s;
}
- if (_append_char(&fmt, '\0') < 0) {
- goto fail;
+ else {
+ info->format = NULL;
}
- info->format = fmt.s;
-
return info;
fail:
@@ -572,9 +577,10 @@ _buffer_info_cmp(_buffer_info_t *a, _buffer_info_t *b)
Py_ssize_t c;
int k;
- c = strcmp(a->format, b->format);
- if (c != 0) return c;
-
+ if (a->format != NULL && b->format != NULL) {
+ c = strcmp(a->format, b->format);
+ if (c != 0) return c;
+ }
c = a->ndim - b->ndim;
if (c != 0) return c;
@@ -599,7 +605,7 @@ _buffer_info_free(_buffer_info_t *info)
/* Get buffer info from the global dictionary */
static _buffer_info_t*
-_buffer_get_info(PyObject *obj, npy_bool f_contiguous)
+_buffer_get_info(PyObject *obj, int flags)
{
PyObject *key = NULL, *item_list = NULL, *item = NULL;
_buffer_info_t *info = NULL, *old_info = NULL;
@@ -612,7 +618,7 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
}
/* Compute information */
- info = _buffer_info_new(obj, f_contiguous);
+ info = _buffer_info_new(obj, flags);
if (info == NULL) {
return NULL;
}
@@ -630,11 +636,9 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
if (item_list_length > 0) {
item = PyList_GetItem(item_list, item_list_length - 1);
old_info = (_buffer_info_t*)PyLong_AsVoidPtr(item);
- if (_buffer_info_cmp(info, old_info) == 0) {
- _buffer_info_free(info);
- info = old_info;
- }
- else {
+ if (_buffer_info_cmp(info, old_info) != 0) {
+ old_info = NULL; /* Can't use this one, but possibly next */
+
if (item_list_length > 1 && info->ndim > 1) {
/*
* Some arrays are C- and F-contiguous and if they have more
@@ -648,12 +652,26 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
*/
item = PyList_GetItem(item_list, item_list_length - 2);
old_info = (_buffer_info_t*)PyLong_AsVoidPtr(item);
- if (_buffer_info_cmp(info, old_info) == 0) {
- _buffer_info_free(info);
- info = old_info;
+ if (_buffer_info_cmp(info, old_info) != 0) {
+ old_info = NULL;
}
}
}
+
+ if (old_info != NULL) {
+ /*
+ * The two info->format are considered equal if one of them
+ * has no format set (meaning the format is arbitrary and can
+ * be modified). If the new info has a format, but we reuse
+ * the old one, this transfers the ownership to the old one.
+ */
+ if (old_info->format == NULL) {
+ old_info->format = info->format;
+ info->format = NULL;
+ }
+ _buffer_info_free(info);
+ info = old_info;
+ }
}
}
else {
@@ -760,7 +778,7 @@ array_getbuffer(PyObject *obj, Py_buffer *view, int flags)
}
/* Fill in information */
- info = _buffer_get_info(obj, (flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS);
+ info = _buffer_get_info(obj, flags);
if (info == NULL) {
goto fail;
}
@@ -825,7 +843,7 @@ void_getbuffer(PyObject *self, Py_buffer *view, int flags)
}
/* Fill in information */
- info = _buffer_get_info(self, 0);
+ info = _buffer_get_info(self, flags);
if (info == NULL) {
goto fail;
}
diff --git a/numpy/core/src/multiarray/scalarapi.c b/numpy/core/src/multiarray/scalarapi.c
index f610ad468..0e93cbbe9 100644
--- a/numpy/core/src/multiarray/scalarapi.c
+++ b/numpy/core/src/multiarray/scalarapi.c
@@ -35,7 +35,7 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr)
{
int type_num;
int align;
- npy_intp memloc;
+ uintptr_t memloc;
if (descr == NULL) {
descr = PyArray_DescrFromScalar(scalar);
type_num = descr->type_num;
@@ -168,7 +168,7 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr)
* Use the alignment flag to figure out where the data begins
* after a PyObject_HEAD
*/
- memloc = (npy_intp)scalar;
+ memloc = (uintptr_t)scalar;
memloc += sizeof(PyObject);
/* now round-up to the nearest alignment value */
align = descr->alignment;
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 1a50927a8..d2ae6ce31 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -2383,6 +2383,50 @@ static PySequenceMethods voidtype_as_sequence = {
};
+/*
+ * This function implements simple buffer export for user defined subclasses
+ * of `np.generic`. All other scalar types override the buffer export.
+ */
+static int
+gentype_arrtype_getbuffer(PyObject *self, Py_buffer *view, int flags)
+{
+ if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
+ PyErr_Format(PyExc_TypeError,
+ "NumPy scalar %R can only exported as a buffer without format.",
+ self);
+ return -1;
+ }
+ PyArray_Descr *descr = PyArray_DescrFromScalar(self);
+ if (descr == NULL) {
+ return -1;
+ }
+ if (!PyDataType_ISUSERDEF(descr)) {
+ /* This path would also reject the (hopefully) impossible "object" */
+ PyErr_Format(PyExc_TypeError,
+ "user-defined scalar %R registered for built-in dtype %S? "
+ "This should be impossible.",
+ self, descr);
+ return -1;
+ }
+ view->ndim = 0;
+ view->len = descr->elsize;
+ view->itemsize = descr->elsize;
+ view->shape = NULL;
+ view->strides = NULL;
+ view->suboffsets = NULL;
+ Py_INCREF(self);
+ view->obj = self;
+ view->buf = scalar_value(self, descr);
+ Py_DECREF(descr);
+ view->format = NULL;
+ return 0;
+}
+
+
+static PyBufferProcs gentype_arrtype_as_buffer = {
+ .bf_getbuffer = (getbufferproc)gentype_arrtype_getbuffer,
+};
+
/**begin repeat
* #name = bool, byte, short, int, long, longlong, ubyte, ushort, uint, ulong,
@@ -3794,6 +3838,7 @@ initialize_numeric_types(void)
PyGenericArrType_Type.tp_alloc = gentype_alloc;
PyGenericArrType_Type.tp_free = (freefunc)gentype_free;
PyGenericArrType_Type.tp_richcompare = gentype_richcompare;
+ PyGenericArrType_Type.tp_as_buffer = &gentype_arrtype_as_buffer;
PyBoolArrType_Type.tp_as_number = &bool_arrtype_as_number;
/*
diff --git a/numpy/core/src/umath/_rational_tests.c.src b/numpy/core/src/umath/_rational_tests.c.src
index 08c259d98..7b1e5627a 100644
--- a/numpy/core/src/umath/_rational_tests.c.src
+++ b/numpy/core/src/umath/_rational_tests.c.src
@@ -663,7 +663,7 @@ static PyGetSetDef pyrational_getset[] = {
static PyTypeObject PyRational_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
- "rational", /* tp_name */
+ "numpy.core._rational_tests.rational", /* tp_name */
sizeof(PyRational), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index d46e4ce9b..291e8ba8e 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -22,6 +22,7 @@ from decimal import Decimal
import numpy as np
import numpy.core._multiarray_tests as _multiarray_tests
+from numpy.core._rational_tests import rational
from numpy.testing import (
assert_, assert_raises, assert_warns, assert_equal, assert_almost_equal,
assert_array_equal, assert_raises_regex, assert_array_almost_equal,
@@ -7143,6 +7144,21 @@ class TestNewBufferProtocol:
_multiarray_tests.get_buffer_info,
np.arange(5)[::2], ('SIMPLE',))
+ @pytest.mark.parametrize(["obj", "error"], [
+ pytest.param(np.array([1, 2], dtype=rational), ValueError, id="array"),
+ pytest.param(rational(1, 2), TypeError, id="scalar")])
+ def test_export_and_pickle_user_dtype(self, obj, error):
+ # User dtypes should export successfully when FORMAT was not requested.
+ with pytest.raises(error):
+ _multiarray_tests.get_buffer_info(obj, ("STRIDED", "FORMAT"))
+
+ _multiarray_tests.get_buffer_info(obj, ("STRIDED",))
+
+ # This is currently also necessary to implement pickling:
+ pickle_obj = pickle.dumps(obj)
+ res = pickle.loads(pickle_obj)
+ assert_array_equal(res, obj)
+
def test_padding(self):
for j in range(8):
x = np.array([(1,), (2,)], dtype={'f0': (int, j)})