summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2023-05-16 09:28:34 -0600
committerGitHub <noreply@github.com>2023-05-16 09:28:34 -0600
commit6a4abb065afa5cd9967b92476cd8a245edf3dfcd (patch)
tree7a16f842650373e869fccdf2299259bd27a989cd
parenta4a951d2ba256f6a391ae6dca30bee2bb491a59f (diff)
parent8b7f69ceae5cd99592f79121a1bd7b014af4833c (diff)
downloadnumpy-6a4abb065afa5cd9967b92476cd8a245edf3dfcd.tar.gz
Merge pull request #23659 from seberg/issue-23029
ENH: Restore TypeError cleanup in array function dispatching
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.c85
-rw-r--r--numpy/core/tests/test_overrides.py19
2 files changed, 102 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c
index 04768504e..3c55e2164 100644
--- a/numpy/core/src/multiarray/arrayfunction_override.c
+++ b/numpy/core/src/multiarray/arrayfunction_override.c
@@ -419,6 +419,9 @@ typedef struct {
PyObject *dict;
PyObject *relevant_arg_func;
PyObject *default_impl;
+ /* The following fields are used to clean up TypeError messages only: */
+ PyObject *dispatcher_name;
+ PyObject *public_name;
} PyArray_ArrayFunctionDispatcherObject;
@@ -428,10 +431,72 @@ dispatcher_dealloc(PyArray_ArrayFunctionDispatcherObject *self)
Py_CLEAR(self->relevant_arg_func);
Py_CLEAR(self->default_impl);
Py_CLEAR(self->dict);
+ Py_CLEAR(self->dispatcher_name);
+ Py_CLEAR(self->public_name);
PyObject_FREE(self);
}
+static void
+fix_name_if_typeerror(PyArray_ArrayFunctionDispatcherObject *self)
+{
+ if (!PyErr_ExceptionMatches(PyExc_TypeError)) {
+ return;
+ }
+
+ PyObject *exc, *val, *tb, *message;
+ PyErr_Fetch(&exc, &val, &tb);
+
+ if (!PyUnicode_CheckExact(val)) {
+ /*
+ * We expect the error to be unnormalized, but maybe it isn't always
+ * the case, so normalize and fetch args[0] if it isn't a string.
+ */
+ PyErr_NormalizeException(&exc, &val, &tb);
+
+ PyObject *args = PyObject_GetAttrString(val, "args");
+ if (args == NULL || !PyTuple_CheckExact(args)
+ || PyTuple_GET_SIZE(args) != 1) {
+ Py_XDECREF(args);
+ goto restore_error;
+ }
+ message = PyTuple_GET_ITEM(args, 0);
+ Py_INCREF(message);
+ Py_DECREF(args);
+ if (!PyUnicode_CheckExact(message)) {
+ Py_DECREF(message);
+ goto restore_error;
+ }
+ }
+ else {
+ Py_INCREF(val);
+ message = val;
+ }
+
+ Py_ssize_t cmp = PyUnicode_Tailmatch(
+ message, self->dispatcher_name, 0, -1, -1);
+ if (cmp <= 0) {
+ Py_DECREF(message);
+ goto restore_error;
+ }
+ Py_SETREF(message, PyUnicode_Replace(
+ message, self->dispatcher_name, self->public_name, 1));
+ if (message == NULL) {
+ goto restore_error;
+ }
+ PyErr_SetObject(PyExc_TypeError, message);
+ Py_DECREF(exc);
+ Py_XDECREF(val);
+ Py_XDECREF(tb);
+ Py_DECREF(message);
+ return;
+
+ restore_error:
+ /* replacement not successful, so restore original error */
+ PyErr_Restore(exc, val, tb);
+}
+
+
static PyObject *
dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
@@ -458,6 +523,7 @@ dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self,
relevant_args = PyObject_Vectorcall(
self->relevant_arg_func, args, len_args, kwnames);
if (relevant_args == NULL) {
+ fix_name_if_typeerror(self);
return NULL;
}
Py_SETREF(relevant_args, PySequence_Fast(relevant_args,
@@ -600,14 +666,31 @@ dispatcher_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwargs)
}
self->vectorcall = (vectorcallfunc)dispatcher_vectorcall;
+ Py_INCREF(self->default_impl);
+ self->dict = NULL;
+ self->dispatcher_name = NULL;
+ self->public_name = NULL;
+
if (self->relevant_arg_func == Py_None) {
/* NULL in the relevant arg function means we use `like=` */
Py_CLEAR(self->relevant_arg_func);
}
else {
+ /* Fetch names to clean up TypeErrors (show actual name) */
Py_INCREF(self->relevant_arg_func);
+ self->dispatcher_name = PyObject_GetAttrString(
+ self->relevant_arg_func, "__qualname__");
+ if (self->dispatcher_name == NULL) {
+ Py_DECREF(self);
+ return NULL;
+ }
+ self->public_name = PyObject_GetAttrString(
+ self->default_impl, "__qualname__");
+ if (self->public_name == NULL) {
+ Py_DECREF(self);
+ return NULL;
+ }
}
- Py_INCREF(self->default_impl);
/* Need to be like a Python function that has arbitrary attributes */
self->dict = PyDict_New();
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 25f551f6f..5924358ea 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -359,6 +359,17 @@ class TestArrayFunctionImplementation:
TypeError, "no implementation found for 'my.func'"):
func(MyArray())
+ @pytest.mark.parametrize("name", ["concatenate", "mean", "asarray"])
+ def test_signature_error_message_simple(self, name):
+ func = getattr(np, name)
+ try:
+ # all of these functions need an argument:
+ func()
+ except TypeError as e:
+ exc = e
+
+ assert exc.args[0].startswith(f"{name}()")
+
def test_signature_error_message(self):
# The lambda function will be named "<lambda>", but the TypeError
# should show the name as "func"
@@ -370,7 +381,7 @@ class TestArrayFunctionImplementation:
pass
try:
- func(bad_arg=3)
+ func._implementation(bad_arg=3)
except TypeError as e:
expected_exception = e
@@ -378,6 +389,12 @@ class TestArrayFunctionImplementation:
func(bad_arg=3)
raise AssertionError("must fail")
except TypeError as exc:
+ if exc.args[0].startswith("_dispatcher"):
+ # We replace the qualname currently, but it used `__name__`
+ # (relevant functions have the same name and qualname anyway)
+ pytest.skip("Python version is not using __qualname__ for "
+ "TypeError formatting.")
+
assert exc.args == expected_exception.args
@pytest.mark.parametrize("value", [234, "this func is not replaced"])