diff options
| author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-12-21 16:46:25 -0600 |
|---|---|---|
| committer | Charles Harris <charlesr.harris@gmail.com> | 2020-12-23 15:04:03 -0700 |
| commit | 4bd709c0fde357811c63bcd5387e7b86f1780751 (patch) | |
| tree | 3f7f7bb0a712d82df9b1740b658079dbe5f887ca /numpy/core | |
| parent | 5b9e5ca696337282a7dae2c9c556ccf614187c61 (diff) | |
| download | numpy-4bd709c0fde357811c63bcd5387e7b86f1780751.tar.gz | |
BUG: Fix concatenation when the output is "S" or "U"
Previously, the dtype was used directly; this now assumes that we want to
cast to a string of (unknown) length. This is a simplified version
of what happens in `np.array()` or `arr.astype()` (it never
inspects the values, e.g. for object casts).
This is more complex than I would like, and with the refactor of
ResultType and similar, it can hopefully be cleaned up a bit more.
Note that currently, object to "S" or "U" casts simply return
length 64 strings, but with the new version, this will be an error
(although the error message probably needs improvement).
This is a behaviour inherited from other places however.
Diffstat (limited to 'numpy/core')
| -rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 67 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/convert_datatype.h | 4 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 57 | ||||
| -rw-r--r-- | numpy/core/tests/test_shape_base.py | 28 |
4 files changed, 116 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index f9dd35a73..5d5b69bd5 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -871,6 +871,73 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType) } +/* + * Helper to find the target descriptor for multiple arrays given an input + * one that may be a DType class (e.g. "U" or "S"). + * Works with arrays, since that is what `concatenate` works with. However, + * unlike `np.array(...)` or `arr.astype()` we will never inspect the array's + * content, which means that object arrays can only be cast to strings if a + * fixed width is provided (same for string -> generic datetime). + * + * As this function uses `PyArray_ExtractDTypeAndDescriptor`, it should + * eventually be refactored to move the step to an earlier point. + */ +NPY_NO_EXPORT PyArray_Descr * +PyArray_FindConcatenationDescriptor( + npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype) +{ + if (requested_dtype == NULL) { + return PyArray_ResultType(n, arrays, 0, NULL); + } + + PyArray_DTypeMeta *common_dtype; + PyArray_Descr *result = NULL; + if (PyArray_ExtractDTypeAndDescriptor( + requested_dtype, &result, &common_dtype) < 0) { + return NULL; + } + if (result != NULL) { + if (result->subarray != NULL) { + PyErr_Format(PyExc_TypeError, + "The dtype `%R` is not a valid dtype for concatenation " + "since it is a subarray dtype (the subarray dimensions " + "would be added as array dimensions).", result); + Py_DECREF(result); + return NULL; + } + goto finish; + } + assert(n > 0); /* concatenate requires at least one array input. 
*/ + PyArray_Descr *descr = PyArray_DESCR(arrays[0]); + result = PyArray_CastDescrToDType(descr, common_dtype); + if (result == NULL || n == 1) { + goto finish; + } + /* + * This could short-cut a bit, calling `common_instance` directly and/or + * returning the `default_descr()` directly. Avoiding that (for now) as + * it would duplicate code from `PyArray_PromoteTypes`. + */ + for (npy_intp i = 1; i < n; i++) { + descr = PyArray_DESCR(arrays[i]); + PyArray_Descr *curr = PyArray_CastDescrToDType(descr, common_dtype); + if (curr == NULL) { + Py_SETREF(result, NULL); + goto finish; + } + Py_SETREF(result, PyArray_PromoteTypes(result, curr)); + Py_DECREF(curr); + if (result == NULL) { + goto finish; + } + } + + finish: + Py_DECREF(common_dtype); + return result; +} + + /** * This function defines the common DType operator. * diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index cc1930f77..97006b952 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -49,6 +49,10 @@ npy_set_invalid_cast_error( NPY_NO_EXPORT PyArray_Descr * PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType); +NPY_NO_EXPORT PyArray_Descr * +PyArray_FindConcatenationDescriptor( + npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype); + NPY_NO_EXPORT int PyArray_AddCastingImplmentation(PyBoundArrayMethodObject *meth); diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index af5949e73..cc747d862 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -448,17 +448,10 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, /* Get the priority subtype for the array */ PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays); - - if (dtype == NULL) { - /* Get the resulting dtype from combining all the arrays */ - dtype = 
(PyArray_Descr *)PyArray_ResultType( - narrays, arrays, 0, NULL); - if (dtype == NULL) { - return NULL; - } - } - else { - Py_INCREF(dtype); + PyArray_Descr *descr = PyArray_FindConcatenationDescriptor( + narrays, arrays, (PyObject *)dtype); + if (descr == NULL) { + return NULL; } /* @@ -467,7 +460,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, * resolution rules matching that of the NpyIter. */ PyArray_CreateMultiSortedStridePerm(narrays, arrays, ndim, strideperm); - s = dtype->elsize; + s = descr->elsize; for (idim = ndim-1; idim >= 0; --idim) { int iperm = strideperm[idim]; strides[iperm] = s; @@ -475,17 +468,13 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, } /* Allocate the array for the result. This steals the 'dtype' reference. */ - ret = (PyArrayObject *)PyArray_NewFromDescr(subtype, - dtype, - ndim, - shape, - strides, - NULL, - 0, - NULL); + ret = (PyArrayObject *)PyArray_NewFromDescr_int( + subtype, descr, ndim, shape, strides, NULL, 0, NULL, + NULL, 0, 1); if (ret == NULL) { return NULL; } + assert(PyArray_DESCR(ret) == descr); } /* @@ -575,32 +564,22 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays, /* Get the priority subtype for the array */ PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays); - if (dtype == NULL) { - /* Get the resulting dtype from combining all the arrays */ - dtype = (PyArray_Descr *)PyArray_ResultType( - narrays, arrays, 0, NULL); - if (dtype == NULL) { - return NULL; - } - } - else { - Py_INCREF(dtype); + PyArray_Descr *descr = PyArray_FindConcatenationDescriptor( + narrays, arrays, (PyObject *)dtype); + if (descr == NULL) { + return NULL; } - stride = dtype->elsize; + stride = descr->elsize; /* Allocate the array for the result. This steals the 'dtype' reference. 
*/ - ret = (PyArrayObject *)PyArray_NewFromDescr(subtype, - dtype, - 1, - &shape, - &stride, - NULL, - 0, - NULL); + ret = (PyArrayObject *)PyArray_NewFromDescr_int( + subtype, descr, 1, &shape, &stride, NULL, 0, NULL, + NULL, 0, 1); if (ret == NULL) { return NULL; } + assert(PyArray_DESCR(ret) == descr); } /* diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index 4e56ace90..9922c9173 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -343,7 +343,7 @@ class TestConcatenate: concatenate((a, b), out=np.empty(4)) @pytest.mark.parametrize("axis", [None, 0]) - @pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"]) + @pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8", "S4"]) @pytest.mark.parametrize("casting", ['no', 'equiv', 'safe', 'same_kind', 'unsafe']) def test_out_and_dtype(self, axis, out_dtype, casting): @@ -369,6 +369,32 @@ class TestConcatenate: with assert_raises(TypeError): concatenate(to_concat, out=out, dtype=out_dtype, axis=axis) + @pytest.mark.parametrize("axis", [None, 0]) + @pytest.mark.parametrize("string_dt", ["S", "U", "S0", "U0"]) + @pytest.mark.parametrize("arrs", + [([0.],), ([0.], [1]), ([0], ["string"], [1.])]) + def test_dtype_with_promotion(self, arrs, string_dt, axis): + # Note that U0 and S0 should be deprecated eventually and changed to + # actually give the empty string result (together with `np.array`) + res = np.concatenate(arrs, axis=axis, dtype=string_dt, casting="unsafe") + assert res.dtype == np.promote_types("d", string_dt) + + @pytest.mark.parametrize("axis", [None, 0]) + def test_string_dtype_does_not_inspect(self, axis): + # The error here currently depends on NPY_USE_NEW_CASTINGIMPL as + # the new version rejects using the "default string length" of 64. + # The new behaviour is better, `np.array()` and `arr.astype()` would + # have to be used instead. 
(currently only raises due to unsafe cast) + with pytest.raises((ValueError, TypeError)): + np.concatenate(([None], [1]), dtype="S", axis=axis) + with pytest.raises((ValueError, TypeError)): + np.concatenate(([None], [1]), dtype="U", axis=axis) + + @pytest.mark.parametrize("axis", [None, 0]) + def test_subarray_error(self, axis): + with pytest.raises(TypeError, match=".*subarray dtype"): + np.concatenate(([1], [1]), dtype="(2,)i", axis=axis) + def test_stack(): # non-iterable input |
