diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-08-21 17:04:41 -0500 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2020-09-02 12:53:32 -0500 |
commit | 76253895fbfbd2b6f5a5d4a5d2c6b96ff2dc5a0c (patch) | |
tree | 4484c6bb263123657280be8636c2aa78526db38b | |
parent | d9075b77586e0c7b536d5ec684bfd93c5bcd9439 (diff) | |
download | numpy-76253895fbfbd2b6f5a5d4a5d2c6b96ff2dc5a0c.tar.gz |
API,MAINT: Rewrite promotion using common DType and common instance
This defines `common_dtype` and `common_instance` (only for parametric
DTypes), and uses them to implement the `PyArray_CommonDType` operation.
`PyArray_CommonDType()` together with the `common_instance` method
then define the existing PromoteTypes.
This does not (yet) affect "value based promotion" as defined by
`PyArray_ResultType()`. We also require the step of casting
to the common DType to define this type of example:
```
np.promote_types("S1", "i8") == np.dtype('S21')
```
This step requires finding the string length corresponding to
the integer (21 characters). This is here handled by the
`PyArray_CastDescrToDType` function. However, that function
still relies on `PyArray_AdaptFlexibleDType` and thus does not
generalize to arbitrary DTypes.
See NEP 42 (currently "Common DType Operations" section):
https://numpy.org/neps/nep-0042-new-dtypes.html#common-dtype-operations
-rw-r--r-- | numpy/core/include/numpy/ndarraytypes.h | 6 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 370 | ||||
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.h | 3 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtypemeta.c | 137 | ||||
-rw-r--r-- | numpy/core/src/multiarray/dtypemeta.h | 16 | ||||
-rw-r--r-- | numpy/core/src/multiarray/usertypes.c | 121 | ||||
-rw-r--r-- | numpy/core/src/multiarray/usertypes.h | 4 |
7 files changed, 350 insertions, 307 deletions
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h index bbcf468c1..df480f96d 100644 --- a/numpy/core/include/numpy/ndarraytypes.h +++ b/numpy/core/include/numpy/ndarraytypes.h @@ -1839,6 +1839,10 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size, PyArray_DTypeMeta *cls, PyTypeObject *obj); typedef PyArray_Descr *(default_descr_function)(PyArray_DTypeMeta *cls); + typedef PyArray_DTypeMeta *(common_dtype_function)( + PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtyep2); + typedef PyArray_Descr *(common_instance_function)( + PyArray_Descr *dtype1, PyArray_Descr *dtyep2); /* * While NumPy DTypes would not need to be heap types the plan is to @@ -1894,6 +1898,8 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size, discover_descr_from_pyobject_function *discover_descr_from_pyobject; is_known_scalar_type_function *is_known_scalar_type; default_descr_function *default_descr; + common_dtype_function *common_dtype; + common_instance_function *common_instance; }; #endif /* NPY_INTERNAL_BUILD */ diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 3d81edc17..f700bdc99 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -1080,6 +1080,50 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType) } +/** + * This function defines the common DType operator. + * + * Note that the common DType will not be "object" (unless one of the dtypes + * is object), even though object can technically represent all values + * correctly. + * + * TODO: Before exposure, we should review the return value (e.g. no error + * when no common DType is found). + * + * @param dtype1 DType class to find the common type for. + * @param dtype2 Second DType class. 
+ * @return The common DType or NULL with an error set + */ +NPY_NO_EXPORT PyArray_DTypeMeta * +PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2) +{ + if (dtype1 == dtype2) { + Py_INCREF(dtype1); + return dtype1; + } + + PyArray_DTypeMeta *common_dtype; + + common_dtype = dtype1->common_dtype(dtype1, dtype2); + if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) { + Py_DECREF(common_dtype); + common_dtype = dtype2->common_dtype(dtype2, dtype1); + } + if (common_dtype == NULL) { + return NULL; + } + if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) { + Py_DECREF(Py_NotImplemented); + PyErr_Format(PyExc_TypeError, + "The DTypes %S and %S do not have a common DType. " + "For example they cannot be stored in a single array unless " + "the dtype is `object`.", dtype1, dtype2); + return NULL; + } + return common_dtype; +} + + /*NUMPY_API * Produces the smallest size and lowest kind type to which both * input types can be cast. @@ -1087,320 +1131,48 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType) NPY_NO_EXPORT PyArray_Descr * PyArray_PromoteTypes(PyArray_Descr *type1, PyArray_Descr *type2) { - int type_num1, type_num2, ret_type_num; + PyArray_DTypeMeta *common_dtype; + PyArray_Descr *res; - /* - * Fast path for identical dtypes. - * - * Non-native-byte-order types are converted to native ones below, so we - * can't quit early. - */ + /* Fast path for identical inputs (NOTE: This path preserves metadata!) 
*/ if (type1 == type2 && PyArray_ISNBO(type1->byteorder)) { Py_INCREF(type1); return type1; } - type_num1 = type1->type_num; - type_num2 = type2->type_num; - - /* If they're built-in types, use the promotion table */ - if (type_num1 < NPY_NTYPES && type_num2 < NPY_NTYPES) { - ret_type_num = _npy_type_promotion_table[type_num1][type_num2]; - /* - * The table doesn't handle string/unicode/void/datetime/timedelta, - * so check the result - */ - if (ret_type_num >= 0) { - return PyArray_DescrFromType(ret_type_num); - } - } - /* If one or both are user defined, calculate it */ - else { - int skind1 = NPY_NOSCALAR, skind2 = NPY_NOSCALAR, skind; - - if (PyArray_CanCastTo(type2, type1)) { - /* Promoted types are always native byte order */ - return ensure_dtype_nbo(type1); - } - else if (PyArray_CanCastTo(type1, type2)) { - /* Promoted types are always native byte order */ - return ensure_dtype_nbo(type2); - } - - /* Convert the 'kind' char into a scalar kind */ - switch (type1->kind) { - case 'b': - skind1 = NPY_BOOL_SCALAR; - break; - case 'u': - skind1 = NPY_INTPOS_SCALAR; - break; - case 'i': - skind1 = NPY_INTNEG_SCALAR; - break; - case 'f': - skind1 = NPY_FLOAT_SCALAR; - break; - case 'c': - skind1 = NPY_COMPLEX_SCALAR; - break; - } - switch (type2->kind) { - case 'b': - skind2 = NPY_BOOL_SCALAR; - break; - case 'u': - skind2 = NPY_INTPOS_SCALAR; - break; - case 'i': - skind2 = NPY_INTNEG_SCALAR; - break; - case 'f': - skind2 = NPY_FLOAT_SCALAR; - break; - case 'c': - skind2 = NPY_COMPLEX_SCALAR; - break; - } - - /* If both are scalars, there may be a promotion possible */ - if (skind1 != NPY_NOSCALAR && skind2 != NPY_NOSCALAR) { - - /* Start with the larger scalar kind */ - skind = (skind1 > skind2) ? 
skind1 : skind2; - ret_type_num = _npy_smallest_type_of_kind_table[skind]; - - for (;;) { - - /* If there is no larger type of this kind, try a larger kind */ - if (ret_type_num < 0) { - ++skind; - /* Use -1 to signal no promoted type found */ - if (skind < NPY_NSCALARKINDS) { - ret_type_num = _npy_smallest_type_of_kind_table[skind]; - } - else { - break; - } - } - - /* If we found a type to which we can promote both, done! */ - if (PyArray_CanCastSafely(type_num1, ret_type_num) && - PyArray_CanCastSafely(type_num2, ret_type_num)) { - return PyArray_DescrFromType(ret_type_num); - } - - /* Try the next larger type of this kind */ - ret_type_num = _npy_next_larger_type_table[ret_type_num]; - } - - } - - PyErr_SetString(PyExc_TypeError, - "invalid type promotion with custom data type"); + common_dtype = PyArray_CommonDType(NPY_DTYPE(type1), NPY_DTYPE(type2)); + if (common_dtype == NULL) { return NULL; } - switch (type_num1) { - /* BOOL can convert to anything except datetime/void */ - case NPY_BOOL: - if (type_num2 == NPY_STRING || type_num2 == NPY_UNICODE) { - int char_size = 1; - if (type_num2 == NPY_UNICODE) { - char_size = 4; - } - if (type2->elsize < 5 * char_size) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type2); - ret = ensure_dtype_nbo(temp); - ret->elsize = 5 * char_size; - Py_DECREF(temp); - return ret; - } - return ensure_dtype_nbo(type2); - } - else if (type_num2 != NPY_DATETIME && type_num2 != NPY_VOID) { - return ensure_dtype_nbo(type2); - } - break; - /* For strings and unicodes, take the larger size */ - case NPY_STRING: - if (type_num2 == NPY_STRING) { - if (type1->elsize > type2->elsize) { - return ensure_dtype_nbo(type1); - } - else { - return ensure_dtype_nbo(type2); - } - } - else if (type_num2 == NPY_UNICODE) { - if (type2->elsize >= type1->elsize * 4) { - return ensure_dtype_nbo(type2); - } - else { - PyArray_Descr *d = PyArray_DescrNewFromType(NPY_UNICODE); - if (d == NULL) { - return NULL; - } - d->elsize = 
type1->elsize * 4; - return d; - } - } - /* Allow NUMBER -> STRING */ - else if (PyTypeNum_ISNUMBER(type_num2)) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type1); - PyDataType_MAKEUNSIZED(temp); - - temp = PyArray_AdaptFlexibleDType(type2, temp); - if (temp == NULL) { - return NULL; - } - if (temp->elsize > type1->elsize) { - ret = ensure_dtype_nbo(temp); - } - else { - ret = ensure_dtype_nbo(type1); - } - Py_DECREF(temp); - return ret; - } - break; - case NPY_UNICODE: - if (type_num2 == NPY_UNICODE) { - if (type1->elsize > type2->elsize) { - return ensure_dtype_nbo(type1); - } - else { - return ensure_dtype_nbo(type2); - } - } - else if (type_num2 == NPY_STRING) { - if (type1->elsize >= type2->elsize * 4) { - return ensure_dtype_nbo(type1); - } - else { - PyArray_Descr *d = PyArray_DescrNewFromType(NPY_UNICODE); - if (d == NULL) { - return NULL; - } - d->elsize = type2->elsize * 4; - return d; - } - } - /* Allow NUMBER -> UNICODE */ - else if (PyTypeNum_ISNUMBER(type_num2)) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type1); - PyDataType_MAKEUNSIZED(temp); - temp = PyArray_AdaptFlexibleDType(type2, temp); - if (temp == NULL) { - return NULL; - } - if (temp->elsize > type1->elsize) { - ret = ensure_dtype_nbo(temp); - } - else { - ret = ensure_dtype_nbo(type1); - } - Py_DECREF(temp); - return ret; - } - break; - case NPY_DATETIME: - case NPY_TIMEDELTA: - if (type_num2 == NPY_DATETIME || type_num2 == NPY_TIMEDELTA) { - return datetime_type_promotion(type1, type2); - } - break; + if (!common_dtype->parametric) { + res = common_dtype->default_descr(common_dtype); + Py_DECREF(common_dtype); + return res; } - switch (type_num2) { - /* BOOL can convert to almost anything */ - case NPY_BOOL: - if (type_num2 == NPY_STRING || type_num2 == NPY_UNICODE) { - int char_size = 1; - if (type_num2 == NPY_UNICODE) { - char_size = 4; - } - if (type2->elsize < 5 * char_size) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = 
PyArray_DescrNew(type2); - ret = ensure_dtype_nbo(temp); - ret->elsize = 5 * char_size; - Py_DECREF(temp); - return ret; - } - return ensure_dtype_nbo(type2); - } - else if (type_num1 != NPY_DATETIME && type_num1 != NPY_TIMEDELTA && - type_num1 != NPY_VOID) { - return ensure_dtype_nbo(type1); - } - break; - case NPY_STRING: - /* Allow NUMBER -> STRING */ - if (PyTypeNum_ISNUMBER(type_num1)) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type2); - PyDataType_MAKEUNSIZED(temp); - temp = PyArray_AdaptFlexibleDType(type1, temp); - if (temp == NULL) { - return NULL; - } - if (temp->elsize > type2->elsize) { - ret = ensure_dtype_nbo(temp); - } - else { - ret = ensure_dtype_nbo(type2); - } - Py_DECREF(temp); - return ret; - } - break; - case NPY_UNICODE: - /* Allow NUMBER -> UNICODE */ - if (PyTypeNum_ISNUMBER(type_num1)) { - PyArray_Descr *ret = NULL; - PyArray_Descr *temp = PyArray_DescrNew(type2); - PyDataType_MAKEUNSIZED(temp); - temp = PyArray_AdaptFlexibleDType(type1, temp); - if (temp == NULL) { - return NULL; - } - if (temp->elsize > type2->elsize) { - ret = ensure_dtype_nbo(temp); - } - else { - ret = ensure_dtype_nbo(type2); - } - Py_DECREF(temp); - return ret; - } - break; - case NPY_TIMEDELTA: - if (PyTypeNum_ISSIGNED(type_num1)) { - return ensure_dtype_nbo(type2); - } - break; + /* Cast the input types to the common DType if necessary */ + type1 = PyArray_CastDescrToDType(type1, common_dtype); + if (type1 == NULL) { + Py_DECREF(common_dtype); + return NULL; } - - /* For types equivalent up to endianness, can return either */ - if (PyArray_CanCastTypeTo(type1, type2, NPY_EQUIV_CASTING)) { - return ensure_dtype_nbo(type1); + type2 = PyArray_CastDescrToDType(type2, common_dtype); + if (type2 == NULL) { + Py_DECREF(type1); + Py_DECREF(common_dtype); + return NULL; } - /* TODO: Also combine fields, subarrays, strings, etc */ - /* - printf("invalid type promotion: "); - PyObject_Print(type1, stdout, 0); - printf(" "); - 
PyObject_Print(type2, stdout, 0); - printf("\n"); - */ - PyErr_SetString(PyExc_TypeError, "invalid type promotion"); - return NULL; + * And find the common instance of the two inputs + * NOTE: Common instance preserves metadata (normally and of one input) + */ + res = common_dtype->common_instance(type1, type2); + Py_DECREF(type1); + Py_DECREF(type2); + Py_DECREF(common_dtype); + return res; } /* diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index 507a72266..a2b36b497 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -10,6 +10,9 @@ PyArray_ObjectType(PyObject *op, int minimum_type); NPY_NO_EXPORT PyArrayObject ** PyArray_ConvertToCommonType(PyObject *op, int *retn); +NPY_NO_EXPORT PyArray_DTypeMeta * +PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2); + NPY_NO_EXPORT int PyArray_ValidType(int type); diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index 6e5bf840e..92f50247e 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ b/numpy/core/src/multiarray/dtypemeta.c @@ -15,6 +15,9 @@ #include "dtypemeta.h" #include "_datetime.h" #include "array_coercion.h" +#include "scalartypes.h" +#include "convert_datatype.h" +#include "usertypes.h" static void @@ -216,6 +219,34 @@ flexible_default_descr(PyArray_DTypeMeta *cls) } +static PyArray_Descr * +string_unicode_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2) +{ + if (descr1->elsize >= descr2->elsize) { + return ensure_dtype_nbo(descr1); + } + else { + return ensure_dtype_nbo(descr2); + } +} + + +static PyArray_Descr * +void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2) +{ + /* + * We currently do not support promotion of void types unless they + * are equivalent. 
+ */ + if (!PyArray_CanCastTypeTo(descr1, descr2, NPY_EQUIV_CASTING)) { + PyErr_SetString(PyExc_TypeError, + "invalid type promotion with structured or void datatype(s)."); + return NULL; + } + Py_INCREF(descr1); + return descr1; +} + static int python_builtins_are_known_scalar_types( PyArray_DTypeMeta *NPY_UNUSED(cls), PyTypeObject *pytype) @@ -289,6 +320,86 @@ string_known_scalar_types( } +/* + * The following set of functions define the common dtype operator for + * the builtin types. + */ +static PyArray_DTypeMeta * +default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) +{ + assert(cls->type_num < NPY_NTYPES); + if (!other->legacy || other->type_num > cls->type_num) { + /* Let the more generic (larger type number) DType handle this */ + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + + /* + * Note: The use of the promotion table should probably be revised at + * some point. It may be most useful to remove it entirely and then + * consider adding a fast path/cache `PyArray_CommonDType()` itself. + */ + int common_num = _npy_type_promotion_table[cls->type_num][other->type_num]; + if (common_num < 0) { + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + return PyArray_DTypeFromTypeNum(common_num); +} + + +static PyArray_DTypeMeta * +string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) +{ + assert(cls->type_num < NPY_NTYPES); + if (!other->legacy || other->type_num > cls->type_num || + other->type_num == NPY_OBJECT) { + /* Let the more generic (larger type number) DType handle this */ + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + /* + * The builtin types are ordered by complexity (aside from object) here. + * Arguably, we should not consider numbers and strings "common", but + * we currently do. 
+ */ + Py_INCREF(cls); + return cls; +} + +static PyArray_DTypeMeta * +datetime_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) +{ + if (cls->type_num == NPY_DATETIME && other->type_num == NPY_TIMEDELTA) { + /* + * TODO: We actually currently do allow promotion here. This is + * currently relied on within `np.add(datetime, timedelta)`, + * while for concatenation the cast step will fail. + */ + Py_INCREF(cls); + return cls; + } + return default_builtin_common_dtype(cls, other); +} + + + +static PyArray_DTypeMeta * +object_common_dtype( + PyArray_DTypeMeta *cls, PyArray_DTypeMeta *NPY_UNUSED(other)) +{ + /* + * The object DType is special in that it can represent everything, + * including all potential user DTypes. + * One reason to defer (or error) here might be if the other DType + * does not support scalars so that e.g. `arr1d[0]` returns a 0-D array + * and `arr.astype(object)` would fail. But object casts are special. + */ + Py_INCREF(cls); + return cls; +} + + /** * This function takes a PyArray_Descr and replaces its base class with * a newly created dtype subclass (DTypeMeta instances). @@ -406,16 +517,28 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) dtype_class->f = descr->f; dtype_class->kind = descr->kind; - /* Strings and voids have (strange) logic around scalars. 
*/ + /* Set default functions (correct for most dtypes, override below) */ dtype_class->default_descr = nonparametric_default_descr; - + dtype_class->discover_descr_from_pyobject = ( + nonparametric_discover_descr_from_pyobject); dtype_class->is_known_scalar_type = python_builtins_are_known_scalar_types; + dtype_class->common_dtype = default_builtin_common_dtype; + dtype_class->common_instance = NULL; - if (PyTypeNum_ISDATETIME(descr->type_num)) { + if (PyTypeNum_ISUSERDEF(descr->type_num)) { + dtype_class->common_dtype = legacy_userdtype_common_dtype_function; + } + else if (descr->type_num == NPY_OBJECT) { + dtype_class->common_dtype = object_common_dtype; + } + else if (PyTypeNum_ISDATETIME(descr->type_num)) { /* Datetimes are flexible, but were not considered previously */ dtype_class->parametric = NPY_TRUE; + dtype_class->default_descr = flexible_default_descr; dtype_class->discover_descr_from_pyobject = ( discover_datetime_and_timedelta_from_pyobject); + dtype_class->common_dtype = datetime_common_dtype; + dtype_class->common_instance = datetime_type_promotion; if (descr->type_num == NPY_DATETIME) { dtype_class->is_known_scalar_type = datetime_known_scalar_types; } @@ -426,18 +549,16 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) if (descr->type_num == NPY_VOID) { dtype_class->discover_descr_from_pyobject = ( void_discover_descr_from_pyobject); + dtype_class->common_instance = void_common_instance; } else { dtype_class->is_known_scalar_type = string_known_scalar_types; dtype_class->discover_descr_from_pyobject = ( string_discover_descr_from_pyobject); + dtype_class->common_dtype = string_unicode_common_dtype; + dtype_class->common_instance = string_unicode_common_instance; } } - else { - /* nonparametric case */ - dtype_class->discover_descr_from_pyobject = ( - nonparametric_discover_descr_from_pyobject); - } if (_PyArray_MapPyTypeToDType(dtype_class, descr->typeobj, PyTypeNum_ISUSERDEF(dtype_class->type_num)) < 0) { diff --git 
a/numpy/core/src/multiarray/dtypemeta.h b/numpy/core/src/multiarray/dtypemeta.h index e0909a7eb..83cf7c07e 100644 --- a/numpy/core/src/multiarray/dtypemeta.h +++ b/numpy/core/src/multiarray/dtypemeta.h @@ -2,6 +2,22 @@ #define _NPY_DTYPEMETA_H #define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr)) +/* + * This function will hopefully be phased out or replaced, but was convenient + * for incremental implementation of new DTypes based on DTypeMeta. + * (Error checking is not required for DescrFromType, assuming that the + * type is valid.) + */ +static NPY_INLINE PyArray_DTypeMeta * +PyArray_DTypeFromTypeNum(int typenum) +{ + PyArray_Descr *descr = PyArray_DescrFromType(typenum); + PyArray_DTypeMeta *dtype = NPY_DTYPE(descr); + Py_INCREF(dtype); + Py_DECREF(descr); + return dtype; +} + NPY_NO_EXPORT int dtypemeta_wrap_legacy_descriptor(PyArray_Descr *dtypem); diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c index 6b6c6bd9d..265ec4be4 100644 --- a/numpy/core/src/multiarray/usertypes.c +++ b/numpy/core/src/multiarray/usertypes.c @@ -38,6 +38,7 @@ maintainer email: oliphant.travis@ieee.org #include "usertypes.h" #include "dtypemeta.h" +#include "scalartypes.h" NPY_NO_EXPORT PyArray_Descr **userdescrs=NULL; @@ -347,3 +348,123 @@ PyArray_RegisterCanCast(PyArray_Descr *descr, int totype, return _append_new(&descr->f->cancastscalarkindto[scalar], totype); } } + + +/* + * Legacy user DTypes implemented the common DType operation + * (as used in type promotion/result_type, and e.g. the type for + * concatenation), by using "safe cast" logic. + * + * New DTypes do have this behaviour generally, but we use can-cast + * when legacy user dtypes are involved. 
+ */ +NPY_NO_EXPORT PyArray_DTypeMeta * +legacy_userdtype_common_dtype_function( + PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other) +{ + int skind1 = NPY_NOSCALAR, skind2 = NPY_NOSCALAR, skind; + + if (!other->legacy) { + /* legacy DTypes can always defer to new style ones */ + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + /* Defer so that only one of the types handles the cast */ + if (cls->type_num < other->type_num) { + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; + } + + /* Check whether casting is possible from one type to the other */ + if (PyArray_CanCastSafely(cls->type_num, other->type_num)) { + Py_INCREF(other); + return other; + } + if (PyArray_CanCastSafely(other->type_num, cls->type_num)) { + Py_INCREF(cls); + return cls; + } + + /* + * The following code used to be part of PyArray_PromoteTypes(). + * We can expect that this code is never used. + * In principle, it allows for promotion of two different user dtypes + * to a single NumPy dtype of the same "kind". In practice + * using the same `kind` as NumPy was never possible due to an + * simplification where `PyArray_EquivTypes(descr1, descr2)` will + * return True if both kind and element size match (e.g. bfloat16 and + * float16 would be equivalent). + * The option is also very obscure and not used in the examples. 
+ */ + + /* Convert the 'kind' char into a scalar kind */ + switch (cls->kind) { + case 'b': + skind1 = NPY_BOOL_SCALAR; + break; + case 'u': + skind1 = NPY_INTPOS_SCALAR; + break; + case 'i': + skind1 = NPY_INTNEG_SCALAR; + break; + case 'f': + skind1 = NPY_FLOAT_SCALAR; + break; + case 'c': + skind1 = NPY_COMPLEX_SCALAR; + break; + } + switch (other->kind) { + case 'b': + skind2 = NPY_BOOL_SCALAR; + break; + case 'u': + skind2 = NPY_INTPOS_SCALAR; + break; + case 'i': + skind2 = NPY_INTNEG_SCALAR; + break; + case 'f': + skind2 = NPY_FLOAT_SCALAR; + break; + case 'c': + skind2 = NPY_COMPLEX_SCALAR; + break; + } + + /* If both are scalars, there may be a promotion possible */ + if (skind1 != NPY_NOSCALAR && skind2 != NPY_NOSCALAR) { + + /* Start with the larger scalar kind */ + skind = (skind1 > skind2) ? skind1 : skind2; + int ret_type_num = _npy_smallest_type_of_kind_table[skind]; + + for (;;) { + + /* If there is no larger type of this kind, try a larger kind */ + if (ret_type_num < 0) { + ++skind; + /* Use -1 to signal no promoted type found */ + if (skind < NPY_NSCALARKINDS) { + ret_type_num = _npy_smallest_type_of_kind_table[skind]; + } + else { + break; + } + } + + /* If we found a type to which we can promote both, done! 
*/ + if (PyArray_CanCastSafely(cls->type_num, ret_type_num) && + PyArray_CanCastSafely(other->type_num, ret_type_num)) { + return PyArray_DTypeFromTypeNum(ret_type_num); + } + + /* Try the next larger type of this kind */ + ret_type_num = _npy_next_larger_type_table[ret_type_num]; + } + } + + Py_INCREF(Py_NotImplemented); + return (PyArray_DTypeMeta *)Py_NotImplemented; +} diff --git a/numpy/core/src/multiarray/usertypes.h b/numpy/core/src/multiarray/usertypes.h index b3e386c5c..1b323d458 100644 --- a/numpy/core/src/multiarray/usertypes.h +++ b/numpy/core/src/multiarray/usertypes.h @@ -17,4 +17,8 @@ NPY_NO_EXPORT int PyArray_RegisterCastFunc(PyArray_Descr *descr, int totype, PyArray_VectorUnaryFunc *castfunc); +NPY_NO_EXPORT PyArray_DTypeMeta * +legacy_userdtype_common_dtype_function( + PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other); + #endif |