summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-09-25 16:11:32 -0600
committerGitHub <noreply@github.com>2021-09-25 16:11:32 -0600
commitac78192390943d90ebae2f4e209e194914d0bc97 (patch)
tree87faf15c03d208e79dfe648f065c81ef195fbae0
parent05fcb6544f72ea173011a77cee14d72979ffe293 (diff)
parent8c89fef9e677afd3ee7777f242b6a53d3b7dfef4 (diff)
downloadnumpy-ac78192390943d90ebae2f4e209e194914d0bc97.tar.gz
Merge pull request #19879 from BvB93/cls_getitem
ENH: Add `__class_getitem__` to `ndarray`, `dtype` and `number`
-rw-r--r--doc/release/upcoming_changes/19879.new_feature.rst15
-rw-r--r--doc/source/reference/arrays.dtypes.rst7
-rw-r--r--doc/source/reference/arrays.ndarray.rst7
-rw-r--r--doc/source/reference/arrays.scalars.rst11
-rw-r--r--numpy/__init__.pyi12
-rw-r--r--numpy/core/_add_newdocs.py98
-rw-r--r--numpy/core/src/multiarray/descriptor.c38
-rw-r--r--numpy/core/src/multiarray/methods.c29
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src95
-rw-r--r--numpy/core/tests/test_arraymethod.py36
-rw-r--r--numpy/core/tests/test_dtype.py40
-rw-r--r--numpy/core/tests/test_scalar_methods.py55
-rw-r--r--numpy/typing/_generic_alias.py5
-rw-r--r--tools/refguide_check.py43
14 files changed, 462 insertions, 29 deletions
diff --git a/doc/release/upcoming_changes/19879.new_feature.rst b/doc/release/upcoming_changes/19879.new_feature.rst
new file mode 100644
index 000000000..c6624138b
--- /dev/null
+++ b/doc/release/upcoming_changes/19879.new_feature.rst
@@ -0,0 +1,15 @@
+``ndarray``, ``dtype`` and ``number`` are now runtime-subscriptable
+-------------------------------------------------------------------
+Mimicking :pep:`585`, the `~numpy.ndarray`, `~numpy.dtype` and `~numpy.number`
+classes are now subscriptable for python 3.9 and later.
+Consequently, expressions that were previously only allowed in .pyi stub files
+or with the help of ``from __future__ import annotations`` are now also legal
+during runtime.
+
+.. code-block:: python
+
+ >>> import numpy as np
+ >>> from typing import Any
+
+ >>> np.ndarray[Any, np.dtype[np.float64]]
+ numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]
diff --git a/doc/source/reference/arrays.dtypes.rst b/doc/source/reference/arrays.dtypes.rst
index b5ffa1a8b..34b0d7085 100644
--- a/doc/source/reference/arrays.dtypes.rst
+++ b/doc/source/reference/arrays.dtypes.rst
@@ -562,3 +562,10 @@ The following methods implement the pickle protocol:
dtype.__reduce__
dtype.__setstate__
+
+Utility method for typing:
+
+.. autosummary::
+ :toctree: generated/
+
+ dtype.__class_getitem__
diff --git a/doc/source/reference/arrays.ndarray.rst b/doc/source/reference/arrays.ndarray.rst
index f2204752d..7831b5f2c 100644
--- a/doc/source/reference/arrays.ndarray.rst
+++ b/doc/source/reference/arrays.ndarray.rst
@@ -621,3 +621,10 @@ String representations:
ndarray.__str__
ndarray.__repr__
+
+Utility method for typing:
+
+.. autosummary::
+ :toctree: generated/
+
+ ndarray.__class_getitem__
diff --git a/doc/source/reference/arrays.scalars.rst b/doc/source/reference/arrays.scalars.rst
index ccab0101e..c691e802f 100644
--- a/doc/source/reference/arrays.scalars.rst
+++ b/doc/source/reference/arrays.scalars.rst
@@ -196,10 +196,10 @@ Inexact types
``f16`` prints as ``0.1`` because it is as close to that value as possible,
whereas the other types do not as they have more precision and therefore have
closer values.
-
+
Conversely, floating-point scalars of different precisions which approximate
the same decimal value may compare unequal despite printing identically:
-
+
>>> f16 = np.float16("0.1")
>>> f32 = np.float32("0.1")
>>> f64 = np.float64("0.1")
@@ -498,6 +498,13 @@ The exceptions to the above rules are given below:
generic.__setstate__
generic.setflags
+Utility method for typing:
+
+.. autosummary::
+ :toctree: generated/
+
+ number.__class_getitem__
+
Defining new types
==================
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 09189a426..27d4a2ab0 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -9,6 +9,9 @@ from abc import abstractmethod
from types import TracebackType, MappingProxyType
from contextlib import ContextDecorator
+if sys.version_info >= (3, 9):
+ from types import GenericAlias
+
from numpy._pytesttester import PytestTester
from numpy.core.multiarray import flagsobj
from numpy.core._internal import _ctypes
@@ -1052,6 +1055,9 @@ class dtype(Generic[_DTypeScalar_co]):
copy: bool = ...,
) -> dtype[object_]: ...
+ if sys.version_info >= (3, 9):
+ def __class_getitem__(self, item: Any) -> GenericAlias: ...
+
@overload
def __getitem__(self: dtype[void], key: List[str]) -> dtype[void]: ...
@overload
@@ -1661,6 +1667,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
strides: None | _ShapeLike = ...,
order: _OrderKACF = ...,
) -> _ArraySelf: ...
+
+ if sys.version_info >= (3, 9):
+ def __class_getitem__(self, item: Any) -> GenericAlias: ...
+
@overload
def __array__(self, dtype: None = ..., /) -> ndarray[Any, _DType_co]: ...
@overload
@@ -2850,6 +2860,8 @@ class number(generic, Generic[_NBit1]): # type: ignore
def real(self: _ArraySelf) -> _ArraySelf: ...
@property
def imag(self: _ArraySelf) -> _ArraySelf: ...
+ if sys.version_info >= (3, 9):
+ def __class_getitem__(self, item: Any) -> GenericAlias: ...
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index 06f2a6376..6731b2e9d 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -796,7 +796,7 @@ add_newdoc('numpy.core.multiarray', 'array',
object : array_like
An array, any object exposing the array interface, an object whose
__array__ method returns an array, or any (nested) sequence.
- If object is a scalar, a 0-dimensional array containing object is
+ If object is a scalar, a 0-dimensional array containing object is
returned.
dtype : data-type, optional
The desired data-type for the array. If not given, then the type will
@@ -2201,8 +2201,8 @@ add_newdoc('numpy.core.multiarray', 'ndarray',
empty : Create an array, but leave its allocated memory unchanged (i.e.,
it contains "garbage").
dtype : Create a data-type.
- numpy.typing.NDArray : A :term:`generic <generic type>` version
- of ndarray.
+ numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
+ w.r.t. its `dtype.type <numpy.dtype.type>`.
Notes
-----
@@ -2798,6 +2798,39 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__copy__',
"""))
+add_newdoc('numpy.core.multiarray', 'ndarray', ('__class_getitem__',
+ """a.__class_getitem__(item, /)
+
+ Return a parametrized wrapper around the `~numpy.ndarray` type.
+
+ .. versionadded:: 1.22
+
+ Returns
+ -------
+ alias : types.GenericAlias
+ A parametrized `~numpy.ndarray` type.
+
+ Examples
+ --------
+ >>> from typing import Any
+ >>> import numpy as np
+
+ >>> np.ndarray[Any, np.dtype[Any]]
+ numpy.ndarray[typing.Any, numpy.dtype[Any]]
+
+ Note
+ ----
+ This method is only available for python 3.9 and later.
+
+ See Also
+ --------
+ :pep:`585` : Type hinting generics in standard collections.
+ numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
+ w.r.t. its `dtype.type <numpy.dtype.type>`.
+
+ """))
+
+
add_newdoc('numpy.core.multiarray', 'ndarray', ('__deepcopy__',
"""a.__deepcopy__(memo, /) -> Deep copy of array.
@@ -6044,6 +6077,35 @@ add_newdoc('numpy.core.multiarray', 'dtype', ('newbyteorder',
"""))
+add_newdoc('numpy.core.multiarray', 'dtype', ('__class_getitem__',
+ """
+ __class_getitem__(item, /)
+
+ Return a parametrized wrapper around the `~numpy.dtype` type.
+
+ .. versionadded:: 1.22
+
+ Returns
+ -------
+ alias : types.GenericAlias
+ A parametrized `~numpy.dtype` type.
+
+ Examples
+ --------
+ >>> import numpy as np
+
+ >>> np.dtype[np.int64]
+ numpy.dtype[numpy.int64]
+
+ Note
+ ----
+ This method is only available for python 3.9 and later.
+
+ See Also
+ --------
+ :pep:`585` : Type hinting generics in standard collections.
+
+ """))
##############################################################################
#
@@ -6465,6 +6527,36 @@ add_newdoc('numpy.core.numerictypes', 'generic',
add_newdoc('numpy.core.numerictypes', 'generic',
refer_to_array_attribute('view'))
+add_newdoc('numpy.core.numerictypes', 'number', ('__class_getitem__',
+ """
+ __class_getitem__(item, /)
+
+ Return a parametrized wrapper around the `~numpy.number` type.
+
+ .. versionadded:: 1.22
+
+ Returns
+ -------
+ alias : types.GenericAlias
+ A parametrized `~numpy.number` type.
+
+ Examples
+ --------
+ >>> from typing import Any
+ >>> import numpy as np
+
+ >>> np.signedinteger[Any]
+ numpy.signedinteger[typing.Any]
+
+ Note
+ ----
+ This method is only available for python 3.9 and later.
+
+ See Also
+ --------
+ :pep:`585` : Type hinting generics in standard collections.
+
+ """))
##############################################################################
#
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 397768f19..082876aa2 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -257,7 +257,7 @@ static PyArray_Descr *
_convert_from_tuple(PyObject *obj, int align)
{
if (PyTuple_GET_SIZE(obj) != 2) {
- PyErr_Format(PyExc_TypeError,
+ PyErr_Format(PyExc_TypeError,
"Tuple must have size 2, but has size %zd",
PyTuple_GET_SIZE(obj));
return NULL;
@@ -449,8 +449,8 @@ _convert_from_array_descr(PyObject *obj, int align)
for (int i = 0; i < n; i++) {
PyObject *item = PyList_GET_ITEM(obj, i);
if (!PyTuple_Check(item) || (PyTuple_GET_SIZE(item) < 2)) {
- PyErr_Format(PyExc_TypeError,
- "Field elements must be 2- or 3-tuples, got '%R'",
+ PyErr_Format(PyExc_TypeError,
+ "Field elements must be 2- or 3-tuples, got '%R'",
item);
goto fail;
}
@@ -461,7 +461,7 @@ _convert_from_array_descr(PyObject *obj, int align)
}
else if (PyTuple_Check(name)) {
if (PyTuple_GET_SIZE(name) != 2) {
- PyErr_Format(PyExc_TypeError,
+ PyErr_Format(PyExc_TypeError,
"If a tuple, the first element of a field tuple must have "
"two elements, not %zd",
PyTuple_GET_SIZE(name));
@@ -475,7 +475,7 @@ _convert_from_array_descr(PyObject *obj, int align)
}
}
else {
- PyErr_SetString(PyExc_TypeError,
+ PyErr_SetString(PyExc_TypeError,
"First element of field tuple is "
"neither a tuple nor str");
goto fail;
@@ -3101,6 +3101,30 @@ arraydescr_newbyteorder(PyArray_Descr *self, PyObject *args)
return (PyObject *)PyArray_DescrNewByteorder(self, endian);
}
+static PyObject *
+arraydescr_class_getitem(PyObject *cls, PyObject *args)
+{
+ PyObject *generic_alias;
+
+#ifdef Py_GENERICALIASOBJECT_H
+ Py_ssize_t args_len;
+
+ args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
+ if (args_len != 1) {
+ return PyErr_Format(PyExc_TypeError,
+ "Too %s arguments for %s",
+ args_len > 1 ? "many" : "few",
+ ((PyTypeObject *)cls)->tp_name);
+ }
+ generic_alias = Py_GenericAlias(cls, args);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "Type subscription requires python >= 3.9");
+ generic_alias = NULL;
+#endif
+ return generic_alias;
+}
+
static PyMethodDef arraydescr_methods[] = {
/* for pickling */
{"__reduce__",
@@ -3112,6 +3136,10 @@ static PyMethodDef arraydescr_methods[] = {
{"newbyteorder",
(PyCFunction)arraydescr_newbyteorder,
METH_VARARGS, NULL},
+ /* for typing; requires python >= 3.9 */
+ {"__class_getitem__",
+ (PyCFunction)arraydescr_class_getitem,
+ METH_CLASS | METH_O, NULL},
{NULL, NULL, 0, NULL} /* sentinel */
};
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 2c10817fa..391e65f6a 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -2699,6 +2699,30 @@ array_complex(PyArrayObject *self, PyObject *NPY_UNUSED(args))
return c;
}
+static PyObject *
+array_class_getitem(PyObject *cls, PyObject *args)
+{
+ PyObject *generic_alias;
+
+#ifdef Py_GENERICALIASOBJECT_H
+ Py_ssize_t args_len;
+
+ args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
+ if (args_len != 2) {
+ return PyErr_Format(PyExc_TypeError,
+ "Too %s arguments for %s",
+ args_len > 2 ? "many" : "few",
+ ((PyTypeObject *)cls)->tp_name);
+ }
+ generic_alias = Py_GenericAlias(cls, args);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "Type subscription requires python >= 3.9");
+ generic_alias = NULL;
+#endif
+ return generic_alias;
+}
+
NPY_NO_EXPORT PyMethodDef array_methods[] = {
/* for subtypes */
@@ -2756,6 +2780,11 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
(PyCFunction) array_format,
METH_VARARGS, NULL},
+ /* for typing; requires python >= 3.9 */
+ {"__class_getitem__",
+ (PyCFunction)array_class_getitem,
+ METH_CLASS | METH_O, NULL},
+
/* Original and Extended methods added 2005 */
{"all",
(PyCFunction)array_all,
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 4faa647ec..93cc9666e 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1805,6 +1805,59 @@ gentype_setflags(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args),
Py_RETURN_NONE;
}
+static PyObject *
+numbertype_class_getitem_abc(PyObject *cls, PyObject *args)
+{
+ PyObject *generic_alias;
+
+#ifdef Py_GENERICALIASOBJECT_H
+ Py_ssize_t args_len;
+ int args_len_expected;
+
+ /* complexfloating should take 2 parameters, all others take 1 */
+ if (PyType_IsSubtype((PyTypeObject *)cls,
+ &PyComplexFloatingArrType_Type)) {
+ args_len_expected = 2;
+ }
+ else {
+ args_len_expected = 1;
+ }
+
+ args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
+ if (args_len != args_len_expected) {
+ return PyErr_Format(PyExc_TypeError,
+ "Too %s arguments for %s",
+ args_len > args_len_expected ? "many" : "few",
+ ((PyTypeObject *)cls)->tp_name);
+ }
+ generic_alias = Py_GenericAlias(cls, args);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "Type subscription requires python >= 3.9");
+ generic_alias = NULL;
+#endif
+ return generic_alias;
+}
+
+/*
+ * Use for concrete np.number subclasses, making them act as if they
+ * were subtyped from e.g. np.signedinteger[object], thus lacking any
+ * free subscription parameters. Requires python >= 3.9.
+ */
+static PyObject *
+numbertype_class_getitem(PyObject *cls, PyObject *args)
+{
+#ifdef Py_GENERICALIASOBJECT_H
+ PyErr_Format(PyExc_TypeError,
+ "There are no type variables left in %s",
+ ((PyTypeObject *)cls)->tp_name);
+#else
+ PyErr_SetString(PyExc_TypeError,
+ "Type subscription requires python >= 3.9");
+#endif
+ return NULL;
+}
+
/*
* casting complex numbers (that don't inherit from Python complex)
* to Python complex
@@ -2188,6 +2241,14 @@ static PyGetSetDef inttype_getsets[] = {
{NULL, NULL, NULL, NULL, NULL}
};
+static PyMethodDef numbertype_methods[] = {
+ /* for typing; requires python >= 3.9 */
+ {"__class_getitem__",
+ (PyCFunction)numbertype_class_getitem_abc,
+ METH_CLASS | METH_O, NULL},
+ {NULL, NULL, 0, NULL} /* sentinel */
+};
+
/**begin repeat
* #name = cfloat,clongdouble#
*/
@@ -2195,6 +2256,10 @@ static PyMethodDef @name@type_methods[] = {
{"__complex__",
(PyCFunction)@name@_complex,
METH_VARARGS | METH_KEYWORDS, NULL},
+ /* for typing; requires python >= 3.9 */
+ {"__class_getitem__",
+ (PyCFunction)numbertype_class_getitem,
+ METH_CLASS | METH_O, NULL},
{NULL, NULL, 0, NULL}
};
/**end repeat**/
@@ -2232,6 +2297,23 @@ static PyMethodDef @name@type_methods[] = {
{"is_integer",
(PyCFunction)@name@_is_integer,
METH_NOARGS, NULL},
+ /* for typing; requires python >= 3.9 */
+ {"__class_getitem__",
+ (PyCFunction)numbertype_class_getitem,
+ METH_CLASS | METH_O, NULL},
+ {NULL, NULL, 0, NULL}
+};
+/**end repeat**/
+
+/**begin repeat
+ * #name = byte, short, int, long, longlong, ubyte, ushort,
+ * uint, ulong, ulonglong, timedelta, cdouble#
+ */
+static PyMethodDef @name@type_methods[] = {
+ /* for typing; requires python >= 3.9 */
+ {"__class_getitem__",
+ (PyCFunction)numbertype_class_getitem,
+ METH_CLASS | METH_O, NULL},
{NULL, NULL, 0, NULL}
};
/**end repeat**/
@@ -3951,6 +4033,8 @@ initialize_numeric_types(void)
PyIntegerArrType_Type.tp_getset = inttype_getsets;
+ PyNumberArrType_Type.tp_methods = numbertype_methods;
+
/**begin repeat
* #NAME= Number, Integer, SignedInteger, UnsignedInteger, Inexact,
* Floating, ComplexFloating, Flexible, Character#
@@ -4016,6 +4100,17 @@ initialize_numeric_types(void)
/**end repeat**/
+ /**begin repeat
+ * #name = byte, short, int, long, longlong, ubyte, ushort,
+ * uint, ulong, ulonglong, timedelta, cdouble#
+ * #Name = Byte, Short, Int, Long, LongLong, UByte, UShort,
+ * UInt, ULong, ULongLong, Timedelta, CDouble#
+ */
+
+ Py@Name@ArrType_Type.tp_methods = @name@type_methods;
+
+ /**end repeat**/
+
/* We won't be inheriting from Python Int type. */
PyIntArrType_Type.tp_hash = int_arrtype_hash;
diff --git a/numpy/core/tests/test_arraymethod.py b/numpy/core/tests/test_arraymethod.py
index b1bc79b80..49aa9f6df 100644
--- a/numpy/core/tests/test_arraymethod.py
+++ b/numpy/core/tests/test_arraymethod.py
@@ -3,6 +3,10 @@ This file tests the generic aspects of ArrayMethod. At the time of writing
this is private API, but when added, public API may be added here.
"""
+import sys
+import types
+from typing import Any, Type
+
import pytest
import numpy as np
@@ -56,3 +60,35 @@ class TestSimpleStridedCall:
# This is private API, which may be modified freely
with pytest.raises(error):
self.method._simple_strided_call(*args)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
+class TestClassGetItem:
+ @pytest.mark.parametrize(
+ "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
+ )
+ def test_class_getitem(self, cls: Type[np.ndarray]) -> None:
+ """Test `ndarray.__class_getitem__`."""
+ alias = cls[Any, Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is cls
+
+ @pytest.mark.parametrize("arg_len", range(4))
+ def test_subscript_tuple(self, arg_len: int) -> None:
+ arg_tup = (Any,) * arg_len
+ if arg_len == 2:
+ assert np.ndarray[arg_tup]
+ else:
+ with pytest.raises(TypeError):
+ np.ndarray[arg_tup]
+
+ def test_subscript_scalar(self) -> None:
+ with pytest.raises(TypeError):
+ np.ndarray[Any]
+
+
+@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
+def test_class_getitem_38() -> None:
+ match = "Type subscription requires python >= 3.9"
+ with pytest.raises(TypeError, match=match):
+ np.ndarray[Any, Any]
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 23269f01b..db4f275b5 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -4,6 +4,8 @@ import pytest
import ctypes
import gc
import warnings
+import types
+from typing import Any
import numpy as np
from numpy.core._rational_tests import rational
@@ -111,9 +113,9 @@ class TestBuiltin:
@pytest.mark.parametrize("dtype",
['Bool', 'Bytes0', 'Complex32', 'Complex64',
'Datetime64', 'Float16', 'Float32', 'Float64',
- 'Int8', 'Int16', 'Int32', 'Int64',
+ 'Int8', 'Int16', 'Int32', 'Int64',
'Object0', 'Str0', 'Timedelta64',
- 'UInt8', 'UInt16', 'Uint32', 'UInt32',
+ 'UInt8', 'UInt16', 'Uint32', 'UInt32',
'Uint64', 'UInt64', 'Void0',
"Float128", "Complex128"])
def test_numeric_style_types_are_invalid(self, dtype):
@@ -1549,3 +1551,37 @@ class TestUserDType:
# Tests that a dtype must have its type field set up to np.dtype
# or in this case a builtin instance.
create_custom_field_dtype(blueprint, mytype, 2)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
+class TestClassGetItem:
+ def test_dtype(self) -> None:
+ alias = np.dtype[Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is np.dtype
+
+ @pytest.mark.parametrize("code", np.typecodes["All"])
+ def test_dtype_subclass(self, code: str) -> None:
+ cls = type(np.dtype(code))
+ alias = cls[Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is cls
+
+ @pytest.mark.parametrize("arg_len", range(4))
+ def test_subscript_tuple(self, arg_len: int) -> None:
+ arg_tup = (Any,) * arg_len
+ if arg_len == 1:
+ assert np.dtype[arg_tup]
+ else:
+ with pytest.raises(TypeError):
+ np.dtype[arg_tup]
+
+ def test_subscript_scalar(self) -> None:
+ assert np.dtype[Any]
+
+
+@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
+def test_class_getitem_38() -> None:
+ match = "Type subscription requires python >= 3.9"
+ with pytest.raises(TypeError, match=match):
+ np.dtype[Any]
diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py
index 94b2dd3c9..6077c8f75 100644
--- a/numpy/core/tests/test_scalar_methods.py
+++ b/numpy/core/tests/test_scalar_methods.py
@@ -1,8 +1,11 @@
"""
Test the scalar constructors, which also do type-coercion
"""
+import sys
import fractions
import platform
+import types
+from typing import Any, Type
import pytest
import numpy as np
@@ -128,3 +131,55 @@ class TestIsInteger:
if value == 0:
continue
assert not value.is_integer()
+
+
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
+class TestClassGetItem:
+ @pytest.mark.parametrize("cls", [
+ np.number,
+ np.integer,
+ np.inexact,
+ np.unsignedinteger,
+ np.signedinteger,
+ np.floating,
+ ])
+ def test_abc(self, cls: Type[np.number]) -> None:
+ alias = cls[Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is cls
+
+ def test_abc_complexfloating(self) -> None:
+ alias = np.complexfloating[Any, Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is np.complexfloating
+
+ @pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character])
+ def test_abc_non_numeric(self, cls: Type[np.generic]) -> None:
+ with pytest.raises(TypeError):
+ cls[Any]
+
+ @pytest.mark.parametrize("code", np.typecodes["All"])
+ def test_concrete(self, code: str) -> None:
+ cls = np.dtype(code).type
+ with pytest.raises(TypeError):
+ cls[Any]
+
+ @pytest.mark.parametrize("arg_len", range(4))
+ def test_subscript_tuple(self, arg_len: int) -> None:
+ arg_tup = (Any,) * arg_len
+ if arg_len == 1:
+ assert np.number[arg_tup]
+ else:
+ with pytest.raises(TypeError):
+ np.number[arg_tup]
+
+ def test_subscript_scalar(self) -> None:
+ assert np.number[Any]
+
+
+@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
+@pytest.mark.parametrize("cls", [np.number, np.complexfloating, np.int64])
+def test_class_getitem_38(cls: Type[np.number]) -> None:
+ match = "Type subscription requires python >= 3.9"
+ with pytest.raises(TypeError, match=match):
+ cls[Any]
diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py
index 5ad5e580c..932f12dd0 100644
--- a/numpy/typing/_generic_alias.py
+++ b/numpy/typing/_generic_alias.py
@@ -205,12 +205,9 @@ else:
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
-if TYPE_CHECKING:
+if TYPE_CHECKING or sys.version_info >= (3, 9):
_DType = np.dtype[ScalarType]
NDArray = np.ndarray[Any, np.dtype[ScalarType]]
-elif sys.version_info >= (3, 9):
- _DType = types.GenericAlias(np.dtype, (ScalarType,))
- NDArray = types.GenericAlias(np.ndarray, (Any, _DType))
else:
_DType = _GenericAlias(np.dtype, (ScalarType,))
NDArray = _GenericAlias(np.ndarray, (Any, _DType))
diff --git a/tools/refguide_check.py b/tools/refguide_check.py
index a6bfc0fe4..21ba5a448 100644
--- a/tools/refguide_check.py
+++ b/tools/refguide_check.py
@@ -93,18 +93,27 @@ OTHER_MODULE_DOCS = {
# these names are known to fail doctesting and we like to keep it that way
# e.g. sometimes pseudocode is acceptable etc
-DOCTEST_SKIPLIST = set([
+#
+# Optionally, a subset of methods can be skipped by setting dict-values
+# to a container of method-names
+DOCTEST_SKIPDICT = {
# cases where NumPy docstrings import things from SciPy:
- 'numpy.lib.vectorize',
- 'numpy.random.standard_gamma',
- 'numpy.random.gamma',
- 'numpy.random.vonmises',
- 'numpy.random.power',
- 'numpy.random.zipf',
+ 'numpy.lib.vectorize': None,
+ 'numpy.random.standard_gamma': None,
+ 'numpy.random.gamma': None,
+ 'numpy.random.vonmises': None,
+ 'numpy.random.power': None,
+ 'numpy.random.zipf': None,
# remote / local file IO with DataSource is problematic in doctest:
- 'numpy.lib.DataSource',
- 'numpy.lib.Repository',
-])
+ 'numpy.lib.DataSource': None,
+ 'numpy.lib.Repository': None,
+}
+if sys.version_info < (3, 9):
+ DOCTEST_SKIPDICT.update({
+ "numpy.core.ndarray": {"__class_getitem__"},
+ "numpy.core.dtype": {"__class_getitem__"},
+ "numpy.core.number": {"__class_getitem__"},
+ })
# Skip non-numpy RST files, historical release notes
# Any single-directory exact match will skip the directory and all subdirs.
@@ -869,8 +878,12 @@ def check_doctests(module, verbose, ns=None,
for name in get_all_dict(module)[0]:
full_name = module.__name__ + '.' + name
- if full_name in DOCTEST_SKIPLIST:
- continue
+ if full_name in DOCTEST_SKIPDICT:
+ skip_methods = DOCTEST_SKIPDICT[full_name]
+ if skip_methods is None:
+ continue
+ else:
+ skip_methods = None
try:
obj = getattr(module, name)
@@ -891,6 +904,10 @@ def check_doctests(module, verbose, ns=None,
traceback.format_exc()))
continue
+ if skip_methods is not None:
+ tests = [i for i in tests if
+ i.name.partition(".")[2] not in skip_methods]
+
success, output = _run_doctests(tests, full_name, verbose,
doctest_warnings)
@@ -971,7 +988,7 @@ def check_doctests_testfile(fname, verbose, ns=None,
results = []
_, short_name = os.path.split(fname)
- if short_name in DOCTEST_SKIPLIST:
+ if short_name in DOCTEST_SKIPDICT:
return results
full_name = fname