summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-09-16 12:15:57 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-09-16 13:12:17 +0200
commiteba93e9d7b64aa9435b12b9fce0ddc1155cc8dc5 (patch)
tree0cb55287239e53d7e45bccb4e3b15c9bb2ffed9c
parent0baeeb163e008a829d195fba48f750f7b517ec61 (diff)
downloadnumpy-eba93e9d7b64aa9435b12b9fce0ddc1155cc8dc5.tar.gz
MAINT: Make `__class_getitem__` available to all python version and perform basic validation of its input arguments
It will still raise on python 3.8, but now with a more explicit exception message
-rw-r--r--numpy/core/_add_newdocs.py136
-rw-r--r--numpy/core/src/multiarray/descriptor.c29
-rw-r--r--numpy/core/src/multiarray/methods.c28
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src48
-rw-r--r--numpy/core/tests/test_arraymethod.py30
-rw-r--r--numpy/core/tests/test_dtype.py12
-rw-r--r--numpy/core/tests/test_scalar_methods.py12
7 files changed, 196 insertions, 99 deletions
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index d758ef724..6731b2e9d 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -9,7 +9,6 @@ NOTE: Many of the methods of ndarray have corresponding functions.
"""
-import sys
from numpy.core.function_base import add_newdoc
from numpy.core.overrides import array_function_like_doc
@@ -2799,38 +2798,37 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__copy__',
"""))
-if sys.version_info > (3, 9):
- add_newdoc('numpy.core.multiarray', 'ndarray', ('__class_getitem__',
- """a.__class_getitem__(item, /)
+add_newdoc('numpy.core.multiarray', 'ndarray', ('__class_getitem__',
+ """a.__class_getitem__(item, /)
- Return a parametrized wrapper around the `~numpy.ndarray` type.
+ Return a parametrized wrapper around the `~numpy.ndarray` type.
- .. versionadded:: 1.22
+ .. versionadded:: 1.22
- Returns
- -------
- alias : types.GenericAlias
- A parametrized `~numpy.ndarray` type.
+ Returns
+ -------
+ alias : types.GenericAlias
+ A parametrized `~numpy.ndarray` type.
- Examples
- --------
- >>> from typing import Any
- >>> import numpy as np
+ Examples
+ --------
+ >>> from typing import Any
+ >>> import numpy as np
- >>> np.ndarray[Any, np.dtype]
- numpy.ndarray[typing.Any, numpy.dtype]
+ >>> np.ndarray[Any, np.dtype[Any]]
+ numpy.ndarray[typing.Any, numpy.dtype[Any]]
- Note
- ----
- This method is only available for python 3.9 and later.
+ Note
+ ----
+ This method is only available for python 3.9 and later.
- See Also
- --------
- :pep:`585` : Type hinting generics in standard collections.
- numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
- w.r.t. its `dtype.type <numpy.dtype.type>`.
+ See Also
+ --------
+ :pep:`585` : Type hinting generics in standard collections.
+ numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
+ w.r.t. its `dtype.type <numpy.dtype.type>`.
- """))
+ """))
add_newdoc('numpy.core.multiarray', 'ndarray', ('__deepcopy__',
@@ -6079,36 +6077,35 @@ add_newdoc('numpy.core.multiarray', 'dtype', ('newbyteorder',
"""))
-if sys.version_info >= (3, 9):
- add_newdoc('numpy.core.multiarray', 'dtype', ('__class_getitem__',
- """
- __class_getitem__(item, /)
+add_newdoc('numpy.core.multiarray', 'dtype', ('__class_getitem__',
+ """
+ __class_getitem__(item, /)
- Return a parametrized wrapper around the `~numpy.dtype` type.
+ Return a parametrized wrapper around the `~numpy.dtype` type.
- .. versionadded:: 1.22
+ .. versionadded:: 1.22
- Returns
- -------
- alias : types.GenericAlias
- A parametrized `~numpy.dtype` type.
+ Returns
+ -------
+ alias : types.GenericAlias
+ A parametrized `~numpy.dtype` type.
- Examples
- --------
- >>> import numpy as np
+ Examples
+ --------
+ >>> import numpy as np
- >>> np.dtype[np.int64]
- numpy.dtype[numpy.int64]
+ >>> np.dtype[np.int64]
+ numpy.dtype[numpy.int64]
- Note
- ----
- This method is only available for python 3.9 and later.
+ Note
+ ----
+ This method is only available for python 3.9 and later.
- See Also
- --------
- :pep:`585` : Type hinting generics in standard collections.
+ See Also
+ --------
+ :pep:`585` : Type hinting generics in standard collections.
- """))
+ """))
##############################################################################
#
@@ -6530,37 +6527,36 @@ add_newdoc('numpy.core.numerictypes', 'generic',
add_newdoc('numpy.core.numerictypes', 'generic',
refer_to_array_attribute('view'))
-if sys.version_info >= (3, 9):
- add_newdoc('numpy.core.numerictypes', 'number', ('__class_getitem__',
- """
- __class_getitem__(item, /)
+add_newdoc('numpy.core.numerictypes', 'number', ('__class_getitem__',
+ """
+ __class_getitem__(item, /)
- Return a parametrized wrapper around the `~numpy.number` type.
+ Return a parametrized wrapper around the `~numpy.number` type.
- .. versionadded:: 1.22
+ .. versionadded:: 1.22
- Returns
- -------
- alias : types.GenericAlias
- A parametrized `~numpy.number` type.
+ Returns
+ -------
+ alias : types.GenericAlias
+ A parametrized `~numpy.number` type.
- Examples
- --------
- >>> from typing import Any
- >>> import numpy as np
+ Examples
+ --------
+ >>> from typing import Any
+ >>> import numpy as np
- >>> np.signedinteger[Any]
- numpy.signedinteger[typing.Any]
+ >>> np.signedinteger[Any]
+ numpy.signedinteger[typing.Any]
- Note
- ----
- This method is only available for python 3.9 and later.
+ Note
+ ----
+ This method is only available for python 3.9 and later.
- See Also
- --------
- :pep:`585` : Type hinting generics in standard collections.
+ See Also
+ --------
+ :pep:`585` : Type hinting generics in standard collections.
- """))
+ """))
##############################################################################
#
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index d55664927..082876aa2 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -3101,6 +3101,30 @@ arraydescr_newbyteorder(PyArray_Descr *self, PyObject *args)
return (PyObject *)PyArray_DescrNewByteorder(self, endian);
}
+static PyObject *
+arraydescr_class_getitem(PyObject *cls, PyObject *args)
+{
+ PyObject *generic_alias;
+
+#ifdef Py_GENERICALIASOBJECT_H
+ Py_ssize_t args_len;
+
+ args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
+ if (args_len != 1) {
+ return PyErr_Format(PyExc_TypeError,
+ "Too %s arguments for %s",
+ args_len > 1 ? "many" : "few",
+ ((PyTypeObject *)cls)->tp_name);
+ }
+ generic_alias = Py_GenericAlias(cls, args);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "Type subscription requires python >= 3.9");
+ generic_alias = NULL;
+#endif
+ return generic_alias;
+}
+
static PyMethodDef arraydescr_methods[] = {
/* for pickling */
{"__reduce__",
@@ -3112,13 +3136,10 @@ static PyMethodDef arraydescr_methods[] = {
{"newbyteorder",
(PyCFunction)arraydescr_newbyteorder,
METH_VARARGS, NULL},
-
/* for typing; requires python >= 3.9 */
- #ifdef Py_GENERICALIASOBJECT_H
{"__class_getitem__",
- (PyCFunction)Py_GenericAlias,
+ (PyCFunction)arraydescr_class_getitem,
METH_CLASS | METH_O, NULL},
- #endif
{NULL, NULL, 0, NULL} /* sentinel */
};
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 43167cbbf..391e65f6a 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -2699,6 +2699,30 @@ array_complex(PyArrayObject *self, PyObject *NPY_UNUSED(args))
return c;
}
+static PyObject *
+array_class_getitem(PyObject *cls, PyObject *args)
+{
+ PyObject *generic_alias;
+
+#ifdef Py_GENERICALIASOBJECT_H
+ Py_ssize_t args_len;
+
+ args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
+ if (args_len != 2) {
+ return PyErr_Format(PyExc_TypeError,
+ "Too %s arguments for %s",
+ args_len > 2 ? "many" : "few",
+ ((PyTypeObject *)cls)->tp_name);
+ }
+ generic_alias = Py_GenericAlias(cls, args);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "Type subscription requires python >= 3.9");
+ generic_alias = NULL;
+#endif
+ return generic_alias;
+}
+
NPY_NO_EXPORT PyMethodDef array_methods[] = {
/* for subtypes */
@@ -2757,11 +2781,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
METH_VARARGS, NULL},
/* for typing; requires python >= 3.9 */
- #ifdef Py_GENERICALIASOBJECT_H
{"__class_getitem__",
- (PyCFunction)Py_GenericAlias,
+ (PyCFunction)array_class_getitem,
METH_CLASS | METH_O, NULL},
- #endif
/* Original and Extended methods added 2005 */
{"all",
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 328581536..cacf4485e 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1805,20 +1805,48 @@ gentype_setflags(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args),
Py_RETURN_NONE;
}
+static PyObject *
+numbertype_class_getitem_abc(PyObject *cls, PyObject *args)
+{
+ PyObject *generic_alias;
+
+#ifdef Py_GENERICALIASOBJECT_H
+ Py_ssize_t args_len;
+
+ args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
+ if (args_len != 1) {
+ return PyErr_Format(PyExc_TypeError,
+ "Too %s arguments for %s",
+ args_len > 1 ? "many" : "few",
+ ((PyTypeObject *)cls)->tp_name);
+ }
+ generic_alias = Py_GenericAlias(cls, args);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "Type subscription requires python >= 3.9");
+ generic_alias = NULL;
+#endif
+ return generic_alias;
+}
+
/*
* Use for concrete np.number subclasses, making them act as if they
* were subtyped from e.g. np.signedinteger[object], thus lacking any
* free subscription parameters. Requires python >= 3.9.
*/
-#ifdef Py_GENERICALIASOBJECT_H
static PyObject *
numbertype_class_getitem(PyObject *cls, PyObject *args)
{
- return PyErr_Format(PyExc_TypeError,
- "There are no type variables left in %s",
- ((PyTypeObject *)cls)->tp_name);
-}
+#ifdef Py_GENERICALIASOBJECT_H
+ PyErr_Format(PyExc_TypeError,
+ "There are no type variables left in %s",
+ ((PyTypeObject *)cls)->tp_name);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "Type subscription requires python >= 3.9");
#endif
+ return NULL;
+}
/*
* casting complex numbers (that don't inherit from Python complex)
@@ -2205,11 +2233,9 @@ static PyGetSetDef inttype_getsets[] = {
static PyMethodDef numbertype_methods[] = {
/* for typing; requires python >= 3.9 */
- #ifdef Py_GENERICALIASOBJECT_H
{"__class_getitem__",
- (PyCFunction)Py_GenericAlias,
+ (PyCFunction)numbertype_class_getitem_abc,
METH_CLASS | METH_O, NULL},
- #endif
{NULL, NULL, 0, NULL} /* sentinel */
};
@@ -2221,11 +2247,9 @@ static PyMethodDef @name@type_methods[] = {
(PyCFunction)@name@_complex,
METH_VARARGS | METH_KEYWORDS, NULL},
/* for typing; requires python >= 3.9 */
- #ifdef Py_GENERICALIASOBJECT_H
{"__class_getitem__",
(PyCFunction)numbertype_class_getitem,
METH_CLASS | METH_O, NULL},
- #endif
{NULL, NULL, 0, NULL}
};
/**end repeat**/
@@ -2264,11 +2288,9 @@ static PyMethodDef @name@type_methods[] = {
(PyCFunction)@name@_is_integer,
METH_NOARGS, NULL},
/* for typing; requires python >= 3.9 */
- #ifdef Py_GENERICALIASOBJECT_H
{"__class_getitem__",
(PyCFunction)numbertype_class_getitem,
METH_CLASS | METH_O, NULL},
- #endif
{NULL, NULL, 0, NULL}
};
/**end repeat**/
@@ -2279,11 +2301,9 @@ static PyMethodDef @name@type_methods[] = {
*/
static PyMethodDef @name@type_methods[] = {
/* for typing; requires python >= 3.9 */
- #ifdef Py_GENERICALIASOBJECT_H
{"__class_getitem__",
(PyCFunction)numbertype_class_getitem,
METH_CLASS | METH_O, NULL},
- #endif
{NULL, NULL, 0, NULL}
};
/**end repeat**/
diff --git a/numpy/core/tests/test_arraymethod.py b/numpy/core/tests/test_arraymethod.py
index 9bd4c54df..1e5db5915 100644
--- a/numpy/core/tests/test_arraymethod.py
+++ b/numpy/core/tests/test_arraymethod.py
@@ -62,12 +62,26 @@ class TestSimpleStridedCall:
self.method._simple_strided_call(*args)
-@pytest.mark.parametrize(
- "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
-)
@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
-def test_class_getitem(cls: Type[np.ndarray]) -> None:
- """Test `ndarray.__class_getitem__`."""
- alias = cls[Any, Any]
- assert isinstance(alias, types.GenericAlias)
- assert alias.__origin__ is cls
+class TestClassGetItem:
+ @pytest.mark.parametrize(
+ "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
+ )
+ def test_class_getitem(self, cls: Type[np.ndarray]) -> None:
+ """Test `ndarray.__class_getitem__`."""
+ alias = cls[Any, Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is cls
+
+ @pytest.mark.parametrize("arg_len", range(4))
+ def test_subscript_tuple(self, arg_len: int) -> None:
+ arg_tup = (Any,) * arg_len
+ if arg_len == 2:
+ assert np.ndarray[arg_tup]
+ else:
+ with pytest.raises(TypeError):
+ np.ndarray[arg_tup]
+
+ def test_subscript_scalar(self) -> None:
+ with pytest.raises(TypeError):
+ np.ndarray[Any]
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 40f17c09c..b438c1e8c 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -1566,3 +1566,15 @@ class TestClassGetItem:
alias = cls[Any]
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is cls
+
+ @pytest.mark.parametrize("arg_len", range(4))
+ def test_subscript_tuple(self, arg_len: int) -> None:
+ arg_tup = (Any,) * arg_len
+ if arg_len == 1:
+ assert np.dtype[arg_tup]
+ else:
+ with pytest.raises(TypeError):
+ np.dtype[arg_tup]
+
+ def test_subscript_scalar(self) -> None:
+ assert np.dtype[Any]
diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py
index 0cdfe99b1..ad22697b2 100644
--- a/numpy/core/tests/test_scalar_methods.py
+++ b/numpy/core/tests/test_scalar_methods.py
@@ -159,3 +159,15 @@ class TestClassGetItem:
cls = np.dtype(code).type
with pytest.raises(TypeError):
cls[Any]
+
+ @pytest.mark.parametrize("arg_len", range(4))
+ def test_subscript_tuple(self, arg_len: int) -> None:
+ arg_tup = (Any,) * arg_len
+ if arg_len == 1:
+ assert np.number[arg_tup]
+ else:
+ with pytest.raises(TypeError):
+ np.number[arg_tup]
+
+ def test_subscript_scalar(self) -> None:
+ assert np.number[Any]