diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2023-04-25 12:22:53 +0200 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2023-04-25 12:39:33 +0200 |
commit | 61610e74340a4a22f2782274600ae34bd882b929 (patch) | |
tree | 5ed0568467fcdb3ca534b8d58432ec6d6e579bac | |
parent | 6f3e1f458e04d13bdd56cff5669f9fd96a25fb66 (diff) | |
download | numpy-61610e74340a4a22f2782274600ae34bd882b929.tar.gz |
ENH: Restore TypeError cleanup in array function dispatching
When the dispathcer raises a TypeError and it starts with the dispatchers
name (or actually __qualname__ not that it normally matters), then it is
nicer for users if we just raise a new error with the public symbol name.
Python does not seem to normalize exception and goes down the unicode path,
but I assume that e.g. PyPy may not do that. And there might be other
weirder reason why we go down the full path. I have manually tested it
by forcing Normalization.
Closes gh-23029
-rw-r--r-- | numpy/core/src/multiarray/arrayfunction_override.c | 82 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 13 |
2 files changed, 93 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c index 04768504e..63d109ecb 100644 --- a/numpy/core/src/multiarray/arrayfunction_override.c +++ b/numpy/core/src/multiarray/arrayfunction_override.c @@ -419,6 +419,9 @@ typedef struct { PyObject *dict; PyObject *relevant_arg_func; PyObject *default_impl; + /* The following fields are used to clean up TypeError messages only: */ + PyObject *dispatcher_name; + PyObject *public_name; } PyArray_ArrayFunctionDispatcherObject; @@ -428,10 +431,69 @@ dispatcher_dealloc(PyArray_ArrayFunctionDispatcherObject *self) Py_CLEAR(self->relevant_arg_func); Py_CLEAR(self->default_impl); Py_CLEAR(self->dict); + Py_CLEAR(self->dispatcher_name); + Py_CLEAR(self->public_name); PyObject_FREE(self); } +static void +fix_name_if_typeerror(PyArray_ArrayFunctionDispatcherObject *self) +{ + if (!PyErr_ExceptionMatches(PyExc_TypeError)) { + return; + } + + PyObject *exc, *val, *tb, *message; + PyErr_Fetch(&exc, &val, &tb); + + if (!PyUnicode_CheckExact(val)) { + /* + * We expect the error to be unnormalized, but maybe it isn't always + * the case, so normalize and fetch args[0] if it isn't a string. + */ + PyErr_NormalizeException(&exc, &val, &tb); + + PyObject *args = PyObject_GetAttrString(val, "args"); + if (args == NULL || !PyTuple_CheckExact(args) + || PyTuple_GET_SIZE(args) != 1) { + Py_XDECREF(args); + goto restore_error; + } + message = PyTuple_GET_ITEM(args, 0); + Py_INCREF(message); + Py_DECREF(args); + if (!PyUnicode_CheckExact(message)) { + Py_DECREF(message); + goto restore_error; + } + } + else { + Py_INCREF(val); + message = val; + } + + Py_ssize_t cmp = PyUnicode_Tailmatch( + message, self->dispatcher_name, 0, -1, 0); + if (cmp <= 0) { + Py_DECREF(message); + goto restore_error; + } + Py_SETREF(message, PyUnicode_Replace( + message, self->dispatcher_name, self->public_name, 1)); + if (message == NULL) { + goto restore_error; + } + PyErr_SetObject(PyExc_TypeError, message); + Py_DECREF(message); + return; + + restore_error: + /* replacement not successful, so restore original error */ + PyErr_Restore(exc, val, tb); +} + + static PyObject * dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) @@ -458,6 +520,7 @@ dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self, relevant_args = PyObject_Vectorcall( self->relevant_arg_func, args, len_args, kwnames); if (relevant_args == NULL) { + fix_name_if_typeerror(self); return NULL; } Py_SETREF(relevant_args, PySequence_Fast(relevant_args, @@ -600,14 +663,31 @@ dispatcher_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwargs) } self->vectorcall = (vectorcallfunc)dispatcher_vectorcall; + Py_INCREF(self->default_impl); + self->dict = NULL; + self->dispatcher_name = NULL; + self->public_name = NULL; + if (self->relevant_arg_func == Py_None) { /* NULL in the relevant arg function means we use `like=` */ Py_CLEAR(self->relevant_arg_func); } else { + /* Fetch names to clean up TypeErrors (show actual name) */ Py_INCREF(self->relevant_arg_func); + self->dispatcher_name = PyObject_GetAttrString( + self->relevant_arg_func, "__qualname__"); + if (self->dispatcher_name == NULL) { + Py_DECREF(self); + return NULL; + } + self->public_name = PyObject_GetAttrString( + self->default_impl, "__qualname__"); + if (self->public_name == NULL) { + Py_DECREF(self); + return NULL; + } } - Py_INCREF(self->default_impl); /* Need to be like a Python function that has arbitrary attributes */ self->dict = PyDict_New(); diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 25f551f6f..65155b207 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -359,6 +359,17 @@ class TestArrayFunctionImplementation: TypeError, "no implementation found for 'my.func'"): func(MyArray()) + @pytest.mark.parametrize("name", ["concatenate", "mean", "asarray"]) + def test_signature_error_message_simple(self, name): + func = getattr(np, name) + try: + # all of these functions need an argument: + func() + except TypeError as e: + exc = e + + assert exc.args[0].startswith(f"{name}()") + def test_signature_error_message(self): # The lambda function will be named "<lambda>", but the TypeError # should show the name as "func" @@ -370,7 +381,7 @@ class TestArrayFunctionImplementation: pass try: - func(bad_arg=3) + func._implementation(bad_arg=3) except TypeError as e: expected_exception = e |