summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Sebastian Berg <sebastian@sipsolutions.net> 2020-08-21 17:04:41 -0500
committer: Sebastian Berg <sebastian@sipsolutions.net> 2020-09-02 12:53:32 -0500
commit: 76253895fbfbd2b6f5a5d4a5d2c6b96ff2dc5a0c (patch)
tree: 4484c6bb263123657280be8636c2aa78526db38b
parent: d9075b77586e0c7b536d5ec684bfd93c5bcd9439 (diff)
download: numpy-76253895fbfbd2b6f5a5d4a5d2c6b96ff2dc5a0c.tar.gz
API,MAINT: Rewrite promotion using common DType and common instance
This defines `common_dtype` and `common_instance` (only for parametric DTypes), and uses them to implement the `PyArray_CommonDType` operation. `PyArray_CommonDType()` together with the `common_instance` method then defines the existing PromoteTypes. This does not (yet) affect "value based promotion" as defined by `PyArray_ResultType()`. We also require the step of casting to the common DType to define this type of example: ``` np.promote_types("S1", "i8") == np.dtype('S21') ``` This step requires finding the string length corresponding to the integer (21 characters). This is handled here by the `PyArray_CastDescrToDType` function. However, that function still relies on `PyArray_AdaptFlexibleDType` and thus does not generalize to arbitrary DTypes. See NEP 42 (currently "Common DType Operations" section): https://numpy.org/neps/nep-0042-new-dtypes.html#common-dtype-operations
-rw-r--r-- numpy/core/include/numpy/ndarraytypes.h | 6
-rw-r--r-- numpy/core/src/multiarray/convert_datatype.c | 370
-rw-r--r-- numpy/core/src/multiarray/convert_datatype.h | 3
-rw-r--r-- numpy/core/src/multiarray/dtypemeta.c | 137
-rw-r--r-- numpy/core/src/multiarray/dtypemeta.h | 16
-rw-r--r-- numpy/core/src/multiarray/usertypes.c | 121
-rw-r--r-- numpy/core/src/multiarray/usertypes.h | 4
7 files changed, 350 insertions, 307 deletions
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h
index bbcf468c1..df480f96d 100644
--- a/numpy/core/include/numpy/ndarraytypes.h
+++ b/numpy/core/include/numpy/ndarraytypes.h
@@ -1839,6 +1839,10 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size,
PyArray_DTypeMeta *cls, PyTypeObject *obj);
typedef PyArray_Descr *(default_descr_function)(PyArray_DTypeMeta *cls);
+ typedef PyArray_DTypeMeta *(common_dtype_function)(
+ PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2);
+ typedef PyArray_Descr *(common_instance_function)(
+ PyArray_Descr *dtype1, PyArray_Descr *dtype2);
/*
* While NumPy DTypes would not need to be heap types the plan is to
@@ -1894,6 +1898,8 @@ typedef void (PyDataMem_EventHookFunc)(void *inp, void *outp, size_t size,
discover_descr_from_pyobject_function *discover_descr_from_pyobject;
is_known_scalar_type_function *is_known_scalar_type;
default_descr_function *default_descr;
+ common_dtype_function *common_dtype;
+ common_instance_function *common_instance;
};
#endif /* NPY_INTERNAL_BUILD */
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index 3d81edc17..f700bdc99 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -1080,6 +1080,50 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType)
}
+/**
+ * This function defines the common DType operator.
+ *
+ * Note that the common DType will not be "object" (unless one of the dtypes
+ * is object), even though object can technically represent all values
+ * correctly.
+ *
+ * TODO: Before exposure, we should review the return value (e.g. no error
+ * when no common DType is found).
+ *
+ * @param dtype1 DType class to find the common type for.
+ * @param dtype2 Second DType class.
+ * @return The common DType or NULL with an error set
+ */
+NPY_NO_EXPORT PyArray_DTypeMeta *
+PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2)
+{
+ if (dtype1 == dtype2) {
+ Py_INCREF(dtype1);
+ return dtype1;
+ }
+
+ PyArray_DTypeMeta *common_dtype;
+
+ common_dtype = dtype1->common_dtype(dtype1, dtype2);
+ if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) {
+ Py_DECREF(common_dtype);
+ common_dtype = dtype2->common_dtype(dtype2, dtype1);
+ }
+ if (common_dtype == NULL) {
+ return NULL;
+ }
+ if (common_dtype == (PyArray_DTypeMeta *)Py_NotImplemented) {
+ Py_DECREF(Py_NotImplemented);
+ PyErr_Format(PyExc_TypeError,
+ "The DTypes %S and %S do not have a common DType. "
+ "For example they cannot be stored in a single array unless "
+ "the dtype is `object`.", dtype1, dtype2);
+ return NULL;
+ }
+ return common_dtype;
+}
+
+
/*NUMPY_API
* Produces the smallest size and lowest kind type to which both
* input types can be cast.
@@ -1087,320 +1131,48 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType)
NPY_NO_EXPORT PyArray_Descr *
PyArray_PromoteTypes(PyArray_Descr *type1, PyArray_Descr *type2)
{
- int type_num1, type_num2, ret_type_num;
+ PyArray_DTypeMeta *common_dtype;
+ PyArray_Descr *res;
- /*
- * Fast path for identical dtypes.
- *
- * Non-native-byte-order types are converted to native ones below, so we
- * can't quit early.
- */
+ /* Fast path for identical inputs (NOTE: This path preserves metadata!) */
if (type1 == type2 && PyArray_ISNBO(type1->byteorder)) {
Py_INCREF(type1);
return type1;
}
- type_num1 = type1->type_num;
- type_num2 = type2->type_num;
-
- /* If they're built-in types, use the promotion table */
- if (type_num1 < NPY_NTYPES && type_num2 < NPY_NTYPES) {
- ret_type_num = _npy_type_promotion_table[type_num1][type_num2];
- /*
- * The table doesn't handle string/unicode/void/datetime/timedelta,
- * so check the result
- */
- if (ret_type_num >= 0) {
- return PyArray_DescrFromType(ret_type_num);
- }
- }
- /* If one or both are user defined, calculate it */
- else {
- int skind1 = NPY_NOSCALAR, skind2 = NPY_NOSCALAR, skind;
-
- if (PyArray_CanCastTo(type2, type1)) {
- /* Promoted types are always native byte order */
- return ensure_dtype_nbo(type1);
- }
- else if (PyArray_CanCastTo(type1, type2)) {
- /* Promoted types are always native byte order */
- return ensure_dtype_nbo(type2);
- }
-
- /* Convert the 'kind' char into a scalar kind */
- switch (type1->kind) {
- case 'b':
- skind1 = NPY_BOOL_SCALAR;
- break;
- case 'u':
- skind1 = NPY_INTPOS_SCALAR;
- break;
- case 'i':
- skind1 = NPY_INTNEG_SCALAR;
- break;
- case 'f':
- skind1 = NPY_FLOAT_SCALAR;
- break;
- case 'c':
- skind1 = NPY_COMPLEX_SCALAR;
- break;
- }
- switch (type2->kind) {
- case 'b':
- skind2 = NPY_BOOL_SCALAR;
- break;
- case 'u':
- skind2 = NPY_INTPOS_SCALAR;
- break;
- case 'i':
- skind2 = NPY_INTNEG_SCALAR;
- break;
- case 'f':
- skind2 = NPY_FLOAT_SCALAR;
- break;
- case 'c':
- skind2 = NPY_COMPLEX_SCALAR;
- break;
- }
-
- /* If both are scalars, there may be a promotion possible */
- if (skind1 != NPY_NOSCALAR && skind2 != NPY_NOSCALAR) {
-
- /* Start with the larger scalar kind */
- skind = (skind1 > skind2) ? skind1 : skind2;
- ret_type_num = _npy_smallest_type_of_kind_table[skind];
-
- for (;;) {
-
- /* If there is no larger type of this kind, try a larger kind */
- if (ret_type_num < 0) {
- ++skind;
- /* Use -1 to signal no promoted type found */
- if (skind < NPY_NSCALARKINDS) {
- ret_type_num = _npy_smallest_type_of_kind_table[skind];
- }
- else {
- break;
- }
- }
-
- /* If we found a type to which we can promote both, done! */
- if (PyArray_CanCastSafely(type_num1, ret_type_num) &&
- PyArray_CanCastSafely(type_num2, ret_type_num)) {
- return PyArray_DescrFromType(ret_type_num);
- }
-
- /* Try the next larger type of this kind */
- ret_type_num = _npy_next_larger_type_table[ret_type_num];
- }
-
- }
-
- PyErr_SetString(PyExc_TypeError,
- "invalid type promotion with custom data type");
+ common_dtype = PyArray_CommonDType(NPY_DTYPE(type1), NPY_DTYPE(type2));
+ if (common_dtype == NULL) {
return NULL;
}
- switch (type_num1) {
- /* BOOL can convert to anything except datetime/void */
- case NPY_BOOL:
- if (type_num2 == NPY_STRING || type_num2 == NPY_UNICODE) {
- int char_size = 1;
- if (type_num2 == NPY_UNICODE) {
- char_size = 4;
- }
- if (type2->elsize < 5 * char_size) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type2);
- ret = ensure_dtype_nbo(temp);
- ret->elsize = 5 * char_size;
- Py_DECREF(temp);
- return ret;
- }
- return ensure_dtype_nbo(type2);
- }
- else if (type_num2 != NPY_DATETIME && type_num2 != NPY_VOID) {
- return ensure_dtype_nbo(type2);
- }
- break;
- /* For strings and unicodes, take the larger size */
- case NPY_STRING:
- if (type_num2 == NPY_STRING) {
- if (type1->elsize > type2->elsize) {
- return ensure_dtype_nbo(type1);
- }
- else {
- return ensure_dtype_nbo(type2);
- }
- }
- else if (type_num2 == NPY_UNICODE) {
- if (type2->elsize >= type1->elsize * 4) {
- return ensure_dtype_nbo(type2);
- }
- else {
- PyArray_Descr *d = PyArray_DescrNewFromType(NPY_UNICODE);
- if (d == NULL) {
- return NULL;
- }
- d->elsize = type1->elsize * 4;
- return d;
- }
- }
- /* Allow NUMBER -> STRING */
- else if (PyTypeNum_ISNUMBER(type_num2)) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type1);
- PyDataType_MAKEUNSIZED(temp);
-
- temp = PyArray_AdaptFlexibleDType(type2, temp);
- if (temp == NULL) {
- return NULL;
- }
- if (temp->elsize > type1->elsize) {
- ret = ensure_dtype_nbo(temp);
- }
- else {
- ret = ensure_dtype_nbo(type1);
- }
- Py_DECREF(temp);
- return ret;
- }
- break;
- case NPY_UNICODE:
- if (type_num2 == NPY_UNICODE) {
- if (type1->elsize > type2->elsize) {
- return ensure_dtype_nbo(type1);
- }
- else {
- return ensure_dtype_nbo(type2);
- }
- }
- else if (type_num2 == NPY_STRING) {
- if (type1->elsize >= type2->elsize * 4) {
- return ensure_dtype_nbo(type1);
- }
- else {
- PyArray_Descr *d = PyArray_DescrNewFromType(NPY_UNICODE);
- if (d == NULL) {
- return NULL;
- }
- d->elsize = type2->elsize * 4;
- return d;
- }
- }
- /* Allow NUMBER -> UNICODE */
- else if (PyTypeNum_ISNUMBER(type_num2)) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type1);
- PyDataType_MAKEUNSIZED(temp);
- temp = PyArray_AdaptFlexibleDType(type2, temp);
- if (temp == NULL) {
- return NULL;
- }
- if (temp->elsize > type1->elsize) {
- ret = ensure_dtype_nbo(temp);
- }
- else {
- ret = ensure_dtype_nbo(type1);
- }
- Py_DECREF(temp);
- return ret;
- }
- break;
- case NPY_DATETIME:
- case NPY_TIMEDELTA:
- if (type_num2 == NPY_DATETIME || type_num2 == NPY_TIMEDELTA) {
- return datetime_type_promotion(type1, type2);
- }
- break;
+ if (!common_dtype->parametric) {
+ res = common_dtype->default_descr(common_dtype);
+ Py_DECREF(common_dtype);
+ return res;
}
- switch (type_num2) {
- /* BOOL can convert to almost anything */
- case NPY_BOOL:
- if (type_num2 == NPY_STRING || type_num2 == NPY_UNICODE) {
- int char_size = 1;
- if (type_num2 == NPY_UNICODE) {
- char_size = 4;
- }
- if (type2->elsize < 5 * char_size) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type2);
- ret = ensure_dtype_nbo(temp);
- ret->elsize = 5 * char_size;
- Py_DECREF(temp);
- return ret;
- }
- return ensure_dtype_nbo(type2);
- }
- else if (type_num1 != NPY_DATETIME && type_num1 != NPY_TIMEDELTA &&
- type_num1 != NPY_VOID) {
- return ensure_dtype_nbo(type1);
- }
- break;
- case NPY_STRING:
- /* Allow NUMBER -> STRING */
- if (PyTypeNum_ISNUMBER(type_num1)) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type2);
- PyDataType_MAKEUNSIZED(temp);
- temp = PyArray_AdaptFlexibleDType(type1, temp);
- if (temp == NULL) {
- return NULL;
- }
- if (temp->elsize > type2->elsize) {
- ret = ensure_dtype_nbo(temp);
- }
- else {
- ret = ensure_dtype_nbo(type2);
- }
- Py_DECREF(temp);
- return ret;
- }
- break;
- case NPY_UNICODE:
- /* Allow NUMBER -> UNICODE */
- if (PyTypeNum_ISNUMBER(type_num1)) {
- PyArray_Descr *ret = NULL;
- PyArray_Descr *temp = PyArray_DescrNew(type2);
- PyDataType_MAKEUNSIZED(temp);
- temp = PyArray_AdaptFlexibleDType(type1, temp);
- if (temp == NULL) {
- return NULL;
- }
- if (temp->elsize > type2->elsize) {
- ret = ensure_dtype_nbo(temp);
- }
- else {
- ret = ensure_dtype_nbo(type2);
- }
- Py_DECREF(temp);
- return ret;
- }
- break;
- case NPY_TIMEDELTA:
- if (PyTypeNum_ISSIGNED(type_num1)) {
- return ensure_dtype_nbo(type2);
- }
- break;
+ /* Cast the input types to the common DType if necessary */
+ type1 = PyArray_CastDescrToDType(type1, common_dtype);
+ if (type1 == NULL) {
+ Py_DECREF(common_dtype);
+ return NULL;
}
-
- /* For types equivalent up to endianness, can return either */
- if (PyArray_CanCastTypeTo(type1, type2, NPY_EQUIV_CASTING)) {
- return ensure_dtype_nbo(type1);
+ type2 = PyArray_CastDescrToDType(type2, common_dtype);
+ if (type2 == NULL) {
+ Py_DECREF(type1);
+ Py_DECREF(common_dtype);
+ return NULL;
}
- /* TODO: Also combine fields, subarrays, strings, etc */
-
/*
- printf("invalid type promotion: ");
- PyObject_Print(type1, stdout, 0);
- printf(" ");
- PyObject_Print(type2, stdout, 0);
- printf("\n");
- */
- PyErr_SetString(PyExc_TypeError, "invalid type promotion");
- return NULL;
+ * And find the common instance of the two inputs
+ * NOTE: Common instance preserves metadata (normally that of one input)
+ */
+ res = common_dtype->common_instance(type1, type2);
+ Py_DECREF(type1);
+ Py_DECREF(type2);
+ Py_DECREF(common_dtype);
+ return res;
}
/*
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index 507a72266..a2b36b497 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -10,6 +10,9 @@ PyArray_ObjectType(PyObject *op, int minimum_type);
NPY_NO_EXPORT PyArrayObject **
PyArray_ConvertToCommonType(PyObject *op, int *retn);
+NPY_NO_EXPORT PyArray_DTypeMeta *
+PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2);
+
NPY_NO_EXPORT int
PyArray_ValidType(int type);
diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c
index 6e5bf840e..92f50247e 100644
--- a/numpy/core/src/multiarray/dtypemeta.c
+++ b/numpy/core/src/multiarray/dtypemeta.c
@@ -15,6 +15,9 @@
#include "dtypemeta.h"
#include "_datetime.h"
#include "array_coercion.h"
+#include "scalartypes.h"
+#include "convert_datatype.h"
+#include "usertypes.h"
static void
@@ -216,6 +219,34 @@ flexible_default_descr(PyArray_DTypeMeta *cls)
}
+static PyArray_Descr *
+string_unicode_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2)
+{
+ if (descr1->elsize >= descr2->elsize) {
+ return ensure_dtype_nbo(descr1);
+ }
+ else {
+ return ensure_dtype_nbo(descr2);
+ }
+}
+
+
+static PyArray_Descr *
+void_common_instance(PyArray_Descr *descr1, PyArray_Descr *descr2)
+{
+ /*
+ * We currently do not support promotion of void types unless they
+ * are equivalent.
+ */
+ if (!PyArray_CanCastTypeTo(descr1, descr2, NPY_EQUIV_CASTING)) {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid type promotion with structured or void datatype(s).");
+ return NULL;
+ }
+ Py_INCREF(descr1);
+ return descr1;
+}
+
static int
python_builtins_are_known_scalar_types(
PyArray_DTypeMeta *NPY_UNUSED(cls), PyTypeObject *pytype)
@@ -289,6 +320,86 @@ string_known_scalar_types(
}
+/*
+ * The following set of functions define the common dtype operator for
+ * the builtin types.
+ */
+static PyArray_DTypeMeta *
+default_builtin_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
+{
+ assert(cls->type_num < NPY_NTYPES);
+ if (!other->legacy || other->type_num > cls->type_num) {
+ /* Let the more generic (larger type number) DType handle this */
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+
+ /*
+ * Note: The use of the promotion table should probably be revised at
+ * some point. It may be most useful to remove it entirely and then
+ * consider adding a fast path/cache `PyArray_CommonDType()` itself.
+ */
+ int common_num = _npy_type_promotion_table[cls->type_num][other->type_num];
+ if (common_num < 0) {
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+ return PyArray_DTypeFromTypeNum(common_num);
+}
+
+
+static PyArray_DTypeMeta *
+string_unicode_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
+{
+ assert(cls->type_num < NPY_NTYPES);
+ if (!other->legacy || other->type_num > cls->type_num ||
+ other->type_num == NPY_OBJECT) {
+ /* Let the more generic (larger type number) DType handle this */
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+ /*
+ * The builtin types are ordered by complexity (aside from object) here.
+ * Arguably, we should not consider numbers and strings "common", but
+ * we currently do.
+ */
+ Py_INCREF(cls);
+ return cls;
+}
+
+static PyArray_DTypeMeta *
+datetime_common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
+{
+ if (cls->type_num == NPY_DATETIME && other->type_num == NPY_TIMEDELTA) {
+ /*
+ * TODO: We actually currently do allow promotion here. This is
+ * currently relied on within `np.add(datetime, timedelta)`,
+ * while for concatenation the cast step will fail.
+ */
+ Py_INCREF(cls);
+ return cls;
+ }
+ return default_builtin_common_dtype(cls, other);
+}
+
+
+
+static PyArray_DTypeMeta *
+object_common_dtype(
+ PyArray_DTypeMeta *cls, PyArray_DTypeMeta *NPY_UNUSED(other))
+{
+ /*
+ * The object DType is special in that it can represent everything,
+ * including all potential user DTypes.
+ * One reason to defer (or error) here might be if the other DType
+ * does not support scalars so that e.g. `arr1d[0]` returns a 0-D array
+ * and `arr.astype(object)` would fail. But object casts are special.
+ */
+ Py_INCREF(cls);
+ return cls;
+}
+
+
/**
* This function takes a PyArray_Descr and replaces its base class with
* a newly created dtype subclass (DTypeMeta instances).
@@ -406,16 +517,28 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
dtype_class->f = descr->f;
dtype_class->kind = descr->kind;
- /* Strings and voids have (strange) logic around scalars. */
+ /* Set default functions (correct for most dtypes, override below) */
dtype_class->default_descr = nonparametric_default_descr;
-
+ dtype_class->discover_descr_from_pyobject = (
+ nonparametric_discover_descr_from_pyobject);
dtype_class->is_known_scalar_type = python_builtins_are_known_scalar_types;
+ dtype_class->common_dtype = default_builtin_common_dtype;
+ dtype_class->common_instance = NULL;
- if (PyTypeNum_ISDATETIME(descr->type_num)) {
+ if (PyTypeNum_ISUSERDEF(descr->type_num)) {
+ dtype_class->common_dtype = legacy_userdtype_common_dtype_function;
+ }
+ else if (descr->type_num == NPY_OBJECT) {
+ dtype_class->common_dtype = object_common_dtype;
+ }
+ else if (PyTypeNum_ISDATETIME(descr->type_num)) {
/* Datetimes are flexible, but were not considered previously */
dtype_class->parametric = NPY_TRUE;
+ dtype_class->default_descr = flexible_default_descr;
dtype_class->discover_descr_from_pyobject = (
discover_datetime_and_timedelta_from_pyobject);
+ dtype_class->common_dtype = datetime_common_dtype;
+ dtype_class->common_instance = datetime_type_promotion;
if (descr->type_num == NPY_DATETIME) {
dtype_class->is_known_scalar_type = datetime_known_scalar_types;
}
@@ -426,18 +549,16 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
if (descr->type_num == NPY_VOID) {
dtype_class->discover_descr_from_pyobject = (
void_discover_descr_from_pyobject);
+ dtype_class->common_instance = void_common_instance;
}
else {
dtype_class->is_known_scalar_type = string_known_scalar_types;
dtype_class->discover_descr_from_pyobject = (
string_discover_descr_from_pyobject);
+ dtype_class->common_dtype = string_unicode_common_dtype;
+ dtype_class->common_instance = string_unicode_common_instance;
}
}
- else {
- /* nonparametric case */
- dtype_class->discover_descr_from_pyobject = (
- nonparametric_discover_descr_from_pyobject);
- }
if (_PyArray_MapPyTypeToDType(dtype_class, descr->typeobj,
PyTypeNum_ISUSERDEF(dtype_class->type_num)) < 0) {
diff --git a/numpy/core/src/multiarray/dtypemeta.h b/numpy/core/src/multiarray/dtypemeta.h
index e0909a7eb..83cf7c07e 100644
--- a/numpy/core/src/multiarray/dtypemeta.h
+++ b/numpy/core/src/multiarray/dtypemeta.h
@@ -2,6 +2,22 @@
#define _NPY_DTYPEMETA_H
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
+/*
+ * This function will hopefully be phased out or replaced, but was convenient
+ * for incremental implementation of new DTypes based on DTypeMeta.
+ * (Error checking is not required for DescrFromType, assuming that the
+ * type is valid.)
+ */
+static NPY_INLINE PyArray_DTypeMeta *
+PyArray_DTypeFromTypeNum(int typenum)
+{
+ PyArray_Descr *descr = PyArray_DescrFromType(typenum);
+ PyArray_DTypeMeta *dtype = NPY_DTYPE(descr);
+ Py_INCREF(dtype);
+ Py_DECREF(descr);
+ return dtype;
+}
+
NPY_NO_EXPORT int
dtypemeta_wrap_legacy_descriptor(PyArray_Descr *dtypem);
diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c
index 6b6c6bd9d..265ec4be4 100644
--- a/numpy/core/src/multiarray/usertypes.c
+++ b/numpy/core/src/multiarray/usertypes.c
@@ -38,6 +38,7 @@ maintainer email: oliphant.travis@ieee.org
#include "usertypes.h"
#include "dtypemeta.h"
+#include "scalartypes.h"
NPY_NO_EXPORT PyArray_Descr **userdescrs=NULL;
@@ -347,3 +348,123 @@ PyArray_RegisterCanCast(PyArray_Descr *descr, int totype,
return _append_new(&descr->f->cancastscalarkindto[scalar], totype);
}
}
+
+
+/*
+ * Legacy user DTypes implemented the common DType operation
+ * (as used in type promotion/result_type, and e.g. the type for
+ * concatenation), by using "safe cast" logic.
+ *
+ * New DTypes do not have this behaviour generally, but we use can-cast
+ * when legacy user dtypes are involved.
+ */
+NPY_NO_EXPORT PyArray_DTypeMeta *
+legacy_userdtype_common_dtype_function(
+ PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
+{
+ int skind1 = NPY_NOSCALAR, skind2 = NPY_NOSCALAR, skind;
+
+ if (!other->legacy) {
+ /* legacy DTypes can always defer to new style ones */
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+ /* Defer so that only one of the types handles the cast */
+ if (cls->type_num < other->type_num) {
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+ }
+
+ /* Check whether casting is possible from one type to the other */
+ if (PyArray_CanCastSafely(cls->type_num, other->type_num)) {
+ Py_INCREF(other);
+ return other;
+ }
+ if (PyArray_CanCastSafely(other->type_num, cls->type_num)) {
+ Py_INCREF(cls);
+ return cls;
+ }
+
+ /*
+ * The following code used to be part of PyArray_PromoteTypes().
+ * We can expect that this code is never used.
+ * In principle, it allows for promotion of two different user dtypes
+ * to a single NumPy dtype of the same "kind". In practice
+ * using the same `kind` as NumPy was never possible due to a
+ * simplification where `PyArray_EquivTypes(descr1, descr2)` will
+ * return True if both kind and element size match (e.g. bfloat16 and
+ * float16 would be equivalent).
+ * The option is also very obscure and not used in the examples.
+ */
+
+ /* Convert the 'kind' char into a scalar kind */
+ switch (cls->kind) {
+ case 'b':
+ skind1 = NPY_BOOL_SCALAR;
+ break;
+ case 'u':
+ skind1 = NPY_INTPOS_SCALAR;
+ break;
+ case 'i':
+ skind1 = NPY_INTNEG_SCALAR;
+ break;
+ case 'f':
+ skind1 = NPY_FLOAT_SCALAR;
+ break;
+ case 'c':
+ skind1 = NPY_COMPLEX_SCALAR;
+ break;
+ }
+ switch (other->kind) {
+ case 'b':
+ skind2 = NPY_BOOL_SCALAR;
+ break;
+ case 'u':
+ skind2 = NPY_INTPOS_SCALAR;
+ break;
+ case 'i':
+ skind2 = NPY_INTNEG_SCALAR;
+ break;
+ case 'f':
+ skind2 = NPY_FLOAT_SCALAR;
+ break;
+ case 'c':
+ skind2 = NPY_COMPLEX_SCALAR;
+ break;
+ }
+
+ /* If both are scalars, there may be a promotion possible */
+ if (skind1 != NPY_NOSCALAR && skind2 != NPY_NOSCALAR) {
+
+ /* Start with the larger scalar kind */
+ skind = (skind1 > skind2) ? skind1 : skind2;
+ int ret_type_num = _npy_smallest_type_of_kind_table[skind];
+
+ for (;;) {
+
+ /* If there is no larger type of this kind, try a larger kind */
+ if (ret_type_num < 0) {
+ ++skind;
+ /* Use -1 to signal no promoted type found */
+ if (skind < NPY_NSCALARKINDS) {
+ ret_type_num = _npy_smallest_type_of_kind_table[skind];
+ }
+ else {
+ break;
+ }
+ }
+
+ /* If we found a type to which we can promote both, done! */
+ if (PyArray_CanCastSafely(cls->type_num, ret_type_num) &&
+ PyArray_CanCastSafely(other->type_num, ret_type_num)) {
+ return PyArray_DTypeFromTypeNum(ret_type_num);
+ }
+
+ /* Try the next larger type of this kind */
+ ret_type_num = _npy_next_larger_type_table[ret_type_num];
+ }
+ }
+
+ Py_INCREF(Py_NotImplemented);
+ return (PyArray_DTypeMeta *)Py_NotImplemented;
+}
diff --git a/numpy/core/src/multiarray/usertypes.h b/numpy/core/src/multiarray/usertypes.h
index b3e386c5c..1b323d458 100644
--- a/numpy/core/src/multiarray/usertypes.h
+++ b/numpy/core/src/multiarray/usertypes.h
@@ -17,4 +17,8 @@ NPY_NO_EXPORT int
PyArray_RegisterCastFunc(PyArray_Descr *descr, int totype,
PyArray_VectorUnaryFunc *castfunc);
+NPY_NO_EXPORT PyArray_DTypeMeta *
+legacy_userdtype_common_dtype_function(
+ PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other);
+
#endif