diff options
author | Matti Picus <matti.picus@gmail.com> | 2020-11-03 18:06:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-11-03 18:06:37 +0200 |
commit | d62b0ee88b20e5946fe49f0ba533b3e547e4d4f1 (patch) | |
tree | 476a0581d9a2a595e337b3bd2982a759c34f6b0e | |
parent | 4c83c0444c68b89b051f7ef8d8eb1a2276439d78 (diff) | |
parent | d02ca96090ea2fed97b7789a855668c1ddc98294 (diff) | |
download | numpy-d62b0ee88b20e5946fe49f0ba533b3e547e4d4f1.tar.gz |
Merge pull request #17295 from seberg/issue-17294
BUG,ENH: fix pickling user-scalars by allowing non-format buffer export
-rw-r--r-- | numpy/core/src/multiarray/buffer.c | 66 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalarapi.c | 4 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 45 | ||||
-rw-r--r-- | numpy/core/src/umath/_rational_tests.c.src | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 16 |
5 files changed, 106 insertions, 27 deletions
diff --git a/numpy/core/src/multiarray/buffer.c b/numpy/core/src/multiarray/buffer.c index e676682de..3b3bba663 100644 --- a/numpy/core/src/multiarray/buffer.c +++ b/numpy/core/src/multiarray/buffer.c @@ -456,7 +456,7 @@ static PyObject *_buffer_info_cache = NULL; /* Fill in the info structure */ static _buffer_info_t* -_buffer_info_new(PyObject *obj, npy_bool f_contiguous) +_buffer_info_new(PyObject *obj, int flags) { /* * Note that the buffer info is cached as PyLongObjects making them appear @@ -514,6 +514,7 @@ _buffer_info_new(PyObject *obj, npy_bool f_contiguous) * (This is unnecessary, but has no effect in the case where * NPY_RELAXED_STRIDES CHECKING is disabled.) */ + int f_contiguous = (flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS; if (PyArray_IS_C_CONTIGUOUS(arr) && !( f_contiguous && PyArray_IS_F_CONTIGUOUS(arr))) { Py_ssize_t sd = PyArray_ITEMSIZE(arr); @@ -547,16 +548,20 @@ _buffer_info_new(PyObject *obj, npy_bool f_contiguous) } /* Fill in format */ - err = _buffer_format_string(descr, &fmt, obj, NULL, NULL); - Py_DECREF(descr); - if (err != 0) { - goto fail; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + err = _buffer_format_string(descr, &fmt, obj, NULL, NULL); + Py_DECREF(descr); + if (err != 0) { + goto fail; + } + if (_append_char(&fmt, '\0') < 0) { + goto fail; + } + info->format = fmt.s; } - if (_append_char(&fmt, '\0') < 0) { - goto fail; + else { + info->format = NULL; } - info->format = fmt.s; - return info; fail: @@ -572,9 +577,10 @@ _buffer_info_cmp(_buffer_info_t *a, _buffer_info_t *b) Py_ssize_t c; int k; - c = strcmp(a->format, b->format); - if (c != 0) return c; - + if (a->format != NULL && b->format != NULL) { + c = strcmp(a->format, b->format); + if (c != 0) return c; + } c = a->ndim - b->ndim; if (c != 0) return c; @@ -599,7 +605,7 @@ _buffer_info_free(_buffer_info_t *info) /* Get buffer info from the global dictionary */ static _buffer_info_t* -_buffer_get_info(PyObject *obj, npy_bool f_contiguous) +_buffer_get_info(PyObject *obj, int flags) { PyObject *key = NULL, *item_list = NULL, *item = NULL; _buffer_info_t *info = NULL, *old_info = NULL; @@ -612,7 +618,7 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous) } /* Compute information */ - info = _buffer_info_new(obj, f_contiguous); + info = _buffer_info_new(obj, flags); if (info == NULL) { return NULL; } @@ -630,11 +636,9 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous) if (item_list_length > 0) { item = PyList_GetItem(item_list, item_list_length - 1); old_info = (_buffer_info_t*)PyLong_AsVoidPtr(item); - if (_buffer_info_cmp(info, old_info) == 0) { - _buffer_info_free(info); - info = old_info; - } - else { + if (_buffer_info_cmp(info, old_info) != 0) { + old_info = NULL; /* Can't use this one, but possibly next */ + if (item_list_length > 1 && info->ndim > 1) { /* * Some arrays are C- and F-contiguous and if they have more @@ -648,12 +652,26 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous) */ item = PyList_GetItem(item_list, item_list_length - 2); old_info = (_buffer_info_t*)PyLong_AsVoidPtr(item); - if (_buffer_info_cmp(info, old_info) == 0) { - _buffer_info_free(info); - info = old_info; + if (_buffer_info_cmp(info, old_info) != 0) { + old_info = NULL; } } } + + if (old_info != NULL) { + /* + * The two info->format are considered equal if one of them + * has no format set (meaning the format is arbitrary and can + * be modified). If the new info has a format, but we reuse + * the old one, this transfers the ownership to the old one. + */ + if (old_info->format == NULL) { + old_info->format = info->format; + info->format = NULL; + } + _buffer_info_free(info); + info = old_info; + } } } else { @@ -760,7 +778,7 @@ array_getbuffer(PyObject *obj, Py_buffer *view, int flags) } /* Fill in information */ - info = _buffer_get_info(obj, (flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS); + info = _buffer_get_info(obj, flags); if (info == NULL) { goto fail; } @@ -825,7 +843,7 @@ void_getbuffer(PyObject *self, Py_buffer *view, int flags) } /* Fill in information */ - info = _buffer_get_info(self, 0); + info = _buffer_get_info(self, flags); if (info == NULL) { goto fail; } diff --git a/numpy/core/src/multiarray/scalarapi.c b/numpy/core/src/multiarray/scalarapi.c index f610ad468..0e93cbbe9 100644 --- a/numpy/core/src/multiarray/scalarapi.c +++ b/numpy/core/src/multiarray/scalarapi.c @@ -35,7 +35,7 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr) { int type_num; int align; - npy_intp memloc; + uintptr_t memloc; if (descr == NULL) { descr = PyArray_DescrFromScalar(scalar); type_num = descr->type_num; @@ -168,7 +168,7 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr) * Use the alignment flag to figure out where the data begins * after a PyObject_HEAD */ - memloc = (npy_intp)scalar; + memloc = (uintptr_t)scalar; memloc += sizeof(PyObject); /* now round-up to the nearest alignment value */ align = descr->alignment; diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 1a50927a8..d2ae6ce31 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -2383,6 +2383,50 @@ static PySequenceMethods voidtype_as_sequence = { }; +/* + * This function implements simple buffer export for user defined subclasses + * of `np.generic`. All other scalar types override the buffer export. + */ +static int +gentype_arrtype_getbuffer(PyObject *self, Py_buffer *view, int flags) +{ + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + PyErr_Format(PyExc_TypeError, + "NumPy scalar %R can only exported as a buffer without format.", + self); + return -1; + } + PyArray_Descr *descr = PyArray_DescrFromScalar(self); + if (descr == NULL) { + return -1; + } + if (!PyDataType_ISUSERDEF(descr)) { + /* This path would also reject the (hopefully) impossible "object" */ + PyErr_Format(PyExc_TypeError, + "user-defined scalar %R registered for built-in dtype %S? " + "This should be impossible.", + self, descr); + return -1; + } + view->ndim = 0; + view->len = descr->elsize; + view->itemsize = descr->elsize; + view->shape = NULL; + view->strides = NULL; + view->suboffsets = NULL; + Py_INCREF(self); + view->obj = self; + view->buf = scalar_value(self, descr); + Py_DECREF(descr); + view->format = NULL; + return 0; +} + + +static PyBufferProcs gentype_arrtype_as_buffer = { + .bf_getbuffer = (getbufferproc)gentype_arrtype_getbuffer, +}; + /**begin repeat * #name = bool, byte, short, int, long, longlong, ubyte, ushort, uint, ulong, @@ -3794,6 +3838,7 @@ initialize_numeric_types(void) PyGenericArrType_Type.tp_alloc = gentype_alloc; PyGenericArrType_Type.tp_free = (freefunc)gentype_free; PyGenericArrType_Type.tp_richcompare = gentype_richcompare; + PyGenericArrType_Type.tp_as_buffer = &gentype_arrtype_as_buffer; PyBoolArrType_Type.tp_as_number = &bool_arrtype_as_number; /* diff --git a/numpy/core/src/umath/_rational_tests.c.src b/numpy/core/src/umath/_rational_tests.c.src index 08c259d98..7b1e5627a 100644 --- a/numpy/core/src/umath/_rational_tests.c.src +++ b/numpy/core/src/umath/_rational_tests.c.src @@ -663,7 +663,7 @@ static PyGetSetDef pyrational_getset[] = { static PyTypeObject PyRational_Type = { PyVarObject_HEAD_INIT(NULL, 0) - "rational", /* tp_name */ + "numpy.core._rational_tests.rational", /* tp_name */ sizeof(PyRational), /* tp_basicsize */ 0, /* tp_itemsize */ 0, /* tp_dealloc */ diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index d46e4ce9b..291e8ba8e 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -22,6 +22,7 @@ from decimal import Decimal import numpy as np import numpy.core._multiarray_tests as _multiarray_tests +from numpy.core._rational_tests import rational from numpy.testing import ( assert_, assert_raises, assert_warns, assert_equal, assert_almost_equal, assert_array_equal, assert_raises_regex, assert_array_almost_equal, @@ -7143,6 +7144,21 @@ class TestNewBufferProtocol: _multiarray_tests.get_buffer_info, np.arange(5)[::2], ('SIMPLE',)) + @pytest.mark.parametrize(["obj", "error"], [ + pytest.param(np.array([1, 2], dtype=rational), ValueError, id="array"), + pytest.param(rational(1, 2), TypeError, id="scalar")]) + def test_export_and_pickle_user_dtype(self, obj, error): + # User dtypes should export successfully when FORMAT was not requested. + with pytest.raises(error): + _multiarray_tests.get_buffer_info(obj, ("STRIDED", "FORMAT")) + + _multiarray_tests.get_buffer_info(obj, ("STRIDED",)) + + # This is currently also necessary to implement pickling: + pickle_obj = pickle.dumps(obj) + res = pickle.loads(pickle_obj) + assert_array_equal(res, obj) + def test_padding(self): for j in range(8): x = np.array([(1,), (2,)], dtype={'f0': (int, j)}) |