From 61610e74340a4a22f2782274600ae34bd882b929 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 25 Apr 2023 12:22:53 +0200 Subject: ENH: Restore TypeError cleanup in array function dispatching When the dispathcer raises a TypeError and it starts with the dispatchers name (or actually __qualname__ not that it normally matters), then it is nicer for users if we just raise a new error with the public symbol name. Python does not seem to normalize exception and goes down the unicode path, but I assume that e.g. PyPy may not do that. And there might be other weirder reason why we go down the full path. I have manually tested it by forcing Normalization. Closes gh-23029 --- numpy/core/src/multiarray/arrayfunction_override.c | 82 +++++++++++++++++++++- numpy/core/tests/test_overrides.py | 13 +++- 2 files changed, 93 insertions(+), 2 deletions(-) (limited to 'numpy/core') diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c index 04768504e..63d109ecb 100644 --- a/numpy/core/src/multiarray/arrayfunction_override.c +++ b/numpy/core/src/multiarray/arrayfunction_override.c @@ -419,6 +419,9 @@ typedef struct { PyObject *dict; PyObject *relevant_arg_func; PyObject *default_impl; + /* The following fields are used to clean up TypeError messages only: */ + PyObject *dispatcher_name; + PyObject *public_name; } PyArray_ArrayFunctionDispatcherObject; @@ -428,10 +431,69 @@ dispatcher_dealloc(PyArray_ArrayFunctionDispatcherObject *self) Py_CLEAR(self->relevant_arg_func); Py_CLEAR(self->default_impl); Py_CLEAR(self->dict); + Py_CLEAR(self->dispatcher_name); + Py_CLEAR(self->public_name); PyObject_FREE(self); } +static void +fix_name_if_typeerror(PyArray_ArrayFunctionDispatcherObject *self) +{ + if (!PyErr_ExceptionMatches(PyExc_TypeError)) { + return; + } + + PyObject *exc, *val, *tb, *message; + PyErr_Fetch(&exc, &val, &tb); + + if (!PyUnicode_CheckExact(val)) { + /* + * We expect the error to be unnormalized, but maybe it isn't always + * the case, so normalize and fetch args[0] if it isn't a string. + */ + PyErr_NormalizeException(&exc, &val, &tb); + + PyObject *args = PyObject_GetAttrString(val, "args"); + if (args == NULL || !PyTuple_CheckExact(args) + || PyTuple_GET_SIZE(args) != 1) { + Py_XDECREF(args); + goto restore_error; + } + message = PyTuple_GET_ITEM(args, 0); + Py_INCREF(message); + Py_DECREF(args); + if (!PyUnicode_CheckExact(message)) { + Py_DECREF(message); + goto restore_error; + } + } + else { + Py_INCREF(val); + message = val; + } + + Py_ssize_t cmp = PyUnicode_Tailmatch( + message, self->dispatcher_name, 0, -1, 0); + if (cmp <= 0) { + Py_DECREF(message); + goto restore_error; + } + Py_SETREF(message, PyUnicode_Replace( + message, self->dispatcher_name, self->public_name, 1)); + if (message == NULL) { + goto restore_error; + } + PyErr_SetObject(PyExc_TypeError, message); + Py_DECREF(message); + return; + + restore_error: + /* replacement not successful, so restore original error */ + PyErr_Restore(exc, val, tb); +} + + static PyObject * dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) @@ -458,6 +520,7 @@ dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self, relevant_args = PyObject_Vectorcall( self->relevant_arg_func, args, len_args, kwnames); if (relevant_args == NULL) { + fix_name_if_typeerror(self); return NULL; } Py_SETREF(relevant_args, PySequence_Fast(relevant_args, @@ -600,14 +663,31 @@ dispatcher_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwargs) } self->vectorcall = (vectorcallfunc)dispatcher_vectorcall; + Py_INCREF(self->default_impl); + self->dict = NULL; + self->dispatcher_name = NULL; + self->public_name = NULL; + if (self->relevant_arg_func == Py_None) { /* NULL in the relevant arg function means we use `like=` */ Py_CLEAR(self->relevant_arg_func); } else { + /* Fetch names to clean up TypeErrors (show actual name) */ Py_INCREF(self->relevant_arg_func); + self->dispatcher_name = PyObject_GetAttrString( + self->relevant_arg_func, "__qualname__"); + if (self->dispatcher_name == NULL) { + Py_DECREF(self); + return NULL; + } + self->public_name = PyObject_GetAttrString( + self->default_impl, "__qualname__"); + if (self->public_name == NULL) { + Py_DECREF(self); + return NULL; + } } - Py_INCREF(self->default_impl); /* Need to be like a Python function that has arbitrary attributes */ self->dict = PyDict_New(); diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 25f551f6f..65155b207 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -359,6 +359,17 @@ class TestArrayFunctionImplementation: TypeError, "no implementation found for 'my.func'"): func(MyArray()) + @pytest.mark.parametrize("name", ["concatenate", "mean", "asarray"]) + def test_signature_error_message_simple(self, name): + func = getattr(np, name) + try: + # all of these functions need an argument: + func() + except TypeError as e: + exc = e + + assert exc.args[0].startswith(f"{name}()") + def test_signature_error_message(self): # The lambda function will be named "", but the TypeError # should show the name as "func" @@ -370,7 +381,7 @@ class TestArrayFunctionImplementation: pass try: - func(bad_arg=3) + func._implementation(bad_arg=3) except TypeError as e: expected_exception = e -- cgit v1.2.1 From 5019e0abc3eda9bbfbead97b08b4302da2c31437 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 25 Apr 2023 12:59:04 +0200 Subject: TST: Skip test on older Python versions which use `__name__` --- numpy/core/tests/test_overrides.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'numpy/core') diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 65155b207..5924358ea 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -389,6 +389,12 @@ class TestArrayFunctionImplementation: func(bad_arg=3) raise AssertionError("must fail") except TypeError as exc: + if exc.args[0].startswith("_dispatcher"): + # We replace the qualname currently, but it used `__name__` + # (relevant functions have the same name and qualname anyway) + pytest.skip("Python version is not using __qualname__ for " + "TypeError formatting.") + assert exc.args == expected_exception.args @pytest.mark.parametrize("value", [234, "this func is not replaced"]) -- cgit v1.2.1 From b4313643c052abb1c7966fbe42b2ae9c17259b59 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 25 Apr 2023 15:54:19 +0200 Subject: BUG: Add missing decref's of replaced error. --- numpy/core/src/multiarray/arrayfunction_override.c | 3 +++ 1 file changed, 3 insertions(+) (limited to 'numpy/core') diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c index 63d109ecb..08d386e8b 100644 --- a/numpy/core/src/multiarray/arrayfunction_override.c +++ b/numpy/core/src/multiarray/arrayfunction_override.c @@ -485,6 +485,9 @@ fix_name_if_typeerror(PyArray_ArrayFunctionDispatcherObject *self) goto restore_error; } PyErr_SetObject(PyExc_TypeError, message); + Py_DECREF(exc); + Py_XDECREF(val); + Py_XDECREF(tb); Py_DECREF(message); return; -- cgit v1.2.1 From 8b7f69ceae5cd99592f79121a1bd7b014af4833c Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 25 Apr 2023 19:12:28 +0200 Subject: MAINT: Seems it should be -1 direction for matching a prefix Not that it mattered, but docs say direction should be either -1 or 1 --- numpy/core/src/multiarray/arrayfunction_override.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/core') diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c index 08d386e8b..3c55e2164 100644 --- a/numpy/core/src/multiarray/arrayfunction_override.c +++ b/numpy/core/src/multiarray/arrayfunction_override.c @@ -474,13 +474,13 @@ fix_name_if_typeerror(PyArray_ArrayFunctionDispatcherObject *self) } Py_ssize_t cmp = PyUnicode_Tailmatch( - message, self->dispatcher_name, 0, -1, 0); + message, self->dispatcher_name, 0, -1, -1); if (cmp <= 0) { Py_DECREF(message); goto restore_error; } Py_SETREF(message, PyUnicode_Replace( - message, self->dispatcher_name, self->public_name, 1)); + message, self->dispatcher_name, self->public_name, 1)); if (message == NULL) { goto restore_error; } -- cgit v1.2.1