diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-09-25 16:11:32 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-25 16:11:32 -0600 |
commit | ac78192390943d90ebae2f4e209e194914d0bc97 (patch) | |
tree | 87faf15c03d208e79dfe648f065c81ef195fbae0 | |
parent | 05fcb6544f72ea173011a77cee14d72979ffe293 (diff) | |
parent | 8c89fef9e677afd3ee7777f242b6a53d3b7dfef4 (diff) | |
download | numpy-ac78192390943d90ebae2f4e209e194914d0bc97.tar.gz |
Merge pull request #19879 from BvB93/cls_getitem
ENH: Add `__class_getitem__` to `ndarray`, `dtype` and `number`
-rw-r--r-- | doc/release/upcoming_changes/19879.new_feature.rst | 15 | ||||
-rw-r--r-- | doc/source/reference/arrays.dtypes.rst | 7 | ||||
-rw-r--r-- | doc/source/reference/arrays.ndarray.rst | 7 | ||||
-rw-r--r-- | doc/source/reference/arrays.scalars.rst | 11 | ||||
-rw-r--r-- | numpy/__init__.pyi | 12 | ||||
-rw-r--r-- | numpy/core/_add_newdocs.py | 98 | ||||
-rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 38 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 29 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 95 | ||||
-rw-r--r-- | numpy/core/tests/test_arraymethod.py | 36 | ||||
-rw-r--r-- | numpy/core/tests/test_dtype.py | 40 | ||||
-rw-r--r-- | numpy/core/tests/test_scalar_methods.py | 55 | ||||
-rw-r--r-- | numpy/typing/_generic_alias.py | 5 | ||||
-rw-r--r-- | tools/refguide_check.py | 43 |
14 files changed, 462 insertions, 29 deletions
diff --git a/doc/release/upcoming_changes/19879.new_feature.rst b/doc/release/upcoming_changes/19879.new_feature.rst new file mode 100644 index 000000000..c6624138b --- /dev/null +++ b/doc/release/upcoming_changes/19879.new_feature.rst @@ -0,0 +1,15 @@ +``ndarray``, ``dtype`` and ``number`` are now runtime-subscriptable +------------------------------------------------------------------- +Mimicking :pep:`585`, the `~numpy.ndarray`, `~numpy.dtype` and `~numpy.number` +classes are now subscriptable for python 3.9 and later. +Consequently, expressions that were previously only allowed in .pyi stub files +or with the help of ``from __future__ import annotations`` are now also legal +during runtime. + +.. code-block:: python + + >>> import numpy as np + >>> from typing import Any + + >>> np.ndarray[Any, np.dtype[np.float64]] + numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]] diff --git a/doc/source/reference/arrays.dtypes.rst b/doc/source/reference/arrays.dtypes.rst index b5ffa1a8b..34b0d7085 100644 --- a/doc/source/reference/arrays.dtypes.rst +++ b/doc/source/reference/arrays.dtypes.rst @@ -562,3 +562,10 @@ The following methods implement the pickle protocol: dtype.__reduce__ dtype.__setstate__ + +Utility method for typing: + +.. autosummary:: + :toctree: generated/ + + dtype.__class_getitem__ diff --git a/doc/source/reference/arrays.ndarray.rst b/doc/source/reference/arrays.ndarray.rst index f2204752d..7831b5f2c 100644 --- a/doc/source/reference/arrays.ndarray.rst +++ b/doc/source/reference/arrays.ndarray.rst @@ -621,3 +621,10 @@ String representations: ndarray.__str__ ndarray.__repr__ + +Utility method for typing: + +.. autosummary:: + :toctree: generated/ + + ndarray.__class_getitem__ diff --git a/doc/source/reference/arrays.scalars.rst b/doc/source/reference/arrays.scalars.rst index ccab0101e..c691e802f 100644 --- a/doc/source/reference/arrays.scalars.rst +++ b/doc/source/reference/arrays.scalars.rst @@ -196,10 +196,10 @@ Inexact types ``f16`` prints as ``0.1`` because it is as close to that value as possible, whereas the other types do not as they have more precision and therefore have closer values. - + Conversely, floating-point scalars of different precisions which approximate the same decimal value may compare unequal despite printing identically: - + >>> f16 = np.float16("0.1") >>> f32 = np.float32("0.1") >>> f64 = np.float64("0.1") @@ -498,6 +498,13 @@ The exceptions to the above rules are given below: generic.__setstate__ generic.setflags +Utility method for typing: + +.. autosummary:: + :toctree: generated/ + + number.__class_getitem__ + Defining new types ================== diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 09189a426..27d4a2ab0 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -9,6 +9,9 @@ from abc import abstractmethod from types import TracebackType, MappingProxyType from contextlib import ContextDecorator +if sys.version_info >= (3, 9): + from types import GenericAlias + from numpy._pytesttester import PytestTester from numpy.core.multiarray import flagsobj from numpy.core._internal import _ctypes @@ -1052,6 +1055,9 @@ class dtype(Generic[_DTypeScalar_co]): copy: bool = ..., ) -> dtype[object_]: ... + if sys.version_info >= (3, 9): + def __class_getitem__(self, item: Any) -> GenericAlias: ... + @overload def __getitem__(self: dtype[void], key: List[str]) -> dtype[void]: ... @overload @@ -1661,6 +1667,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): strides: None | _ShapeLike = ..., order: _OrderKACF = ..., ) -> _ArraySelf: ... + + if sys.version_info >= (3, 9): + def __class_getitem__(self, item: Any) -> GenericAlias: ... + @overload def __array__(self, dtype: None = ..., /) -> ndarray[Any, _DType_co]: ... @overload @@ -2850,6 +2860,8 @@ class number(generic, Generic[_NBit1]): # type: ignore def real(self: _ArraySelf) -> _ArraySelf: ... @property def imag(self: _ArraySelf) -> _ArraySelf: ... + if sys.version_info >= (3, 9): + def __class_getitem__(self, item: Any) -> GenericAlias: ... def __int__(self) -> int: ... def __float__(self) -> float: ... def __complex__(self) -> complex: ... diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py index 06f2a6376..6731b2e9d 100644 --- a/numpy/core/_add_newdocs.py +++ b/numpy/core/_add_newdocs.py @@ -796,7 +796,7 @@ add_newdoc('numpy.core.multiarray', 'array', object : array_like An array, any object exposing the array interface, an object whose __array__ method returns an array, or any (nested) sequence. - If object is a scalar, a 0-dimensional array containing object is + If object is a scalar, a 0-dimensional array containing object is returned. dtype : data-type, optional The desired data-type for the array. If not given, then the type will @@ -2201,8 +2201,8 @@ add_newdoc('numpy.core.multiarray', 'ndarray', empty : Create an array, but leave its allocated memory unchanged (i.e., it contains "garbage"). dtype : Create a data-type. - numpy.typing.NDArray : A :term:`generic <generic type>` version - of ndarray. + numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>` + w.r.t. its `dtype.type <numpy.dtype.type>`. Notes ----- @@ -2798,6 +2798,39 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__copy__', """)) +add_newdoc('numpy.core.multiarray', 'ndarray', ('__class_getitem__', + """a.__class_getitem__(item, /) + + Return a parametrized wrapper around the `~numpy.ndarray` type. + + .. versionadded:: 1.22 + + Returns + ------- + alias : types.GenericAlias + A parametrized `~numpy.ndarray` type. + + Examples + -------- + >>> from typing import Any + >>> import numpy as np + + >>> np.ndarray[Any, np.dtype[Any]] + numpy.ndarray[typing.Any, numpy.dtype[Any]] + + Note + ---- + This method is only available for python 3.9 and later. + + See Also + -------- + :pep:`585` : Type hinting generics in standard collections. + numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>` + w.r.t. its `dtype.type <numpy.dtype.type>`. + + """)) + + add_newdoc('numpy.core.multiarray', 'ndarray', ('__deepcopy__', """a.__deepcopy__(memo, /) -> Deep copy of array. @@ -6044,6 +6077,35 @@ add_newdoc('numpy.core.multiarray', 'dtype', ('newbyteorder', """)) +add_newdoc('numpy.core.multiarray', 'dtype', ('__class_getitem__', + """ + __class_getitem__(item, /) + + Return a parametrized wrapper around the `~numpy.dtype` type. + + .. versionadded:: 1.22 + + Returns + ------- + alias : types.GenericAlias + A parametrized `~numpy.dtype` type. + + Examples + -------- + >>> import numpy as np + + >>> np.dtype[np.int64] + numpy.dtype[numpy.int64] + + Note + ---- + This method is only available for python 3.9 and later. + + See Also + -------- + :pep:`585` : Type hinting generics in standard collections. + + """)) ############################################################################## # @@ -6465,6 +6527,36 @@ add_newdoc('numpy.core.numerictypes', 'generic', add_newdoc('numpy.core.numerictypes', 'generic', refer_to_array_attribute('view')) +add_newdoc('numpy.core.numerictypes', 'number', ('__class_getitem__', + """ + __class_getitem__(item, /) + + Return a parametrized wrapper around the `~numpy.number` type. + + .. versionadded:: 1.22 + + Returns + ------- + alias : types.GenericAlias + A parametrized `~numpy.number` type. + + Examples + -------- + >>> from typing import Any + >>> import numpy as np + + >>> np.signedinteger[Any] + numpy.signedinteger[typing.Any] + + Note + ---- + This method is only available for python 3.9 and later. + + See Also + -------- + :pep:`585` : Type hinting generics in standard collections. + + """)) ############################################################################## # diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 397768f19..082876aa2 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -257,7 +257,7 @@ static PyArray_Descr * _convert_from_tuple(PyObject *obj, int align) { if (PyTuple_GET_SIZE(obj) != 2) { - PyErr_Format(PyExc_TypeError, + PyErr_Format(PyExc_TypeError, "Tuple must have size 2, but has size %zd", PyTuple_GET_SIZE(obj)); return NULL; @@ -449,8 +449,8 @@ _convert_from_array_descr(PyObject *obj, int align) for (int i = 0; i < n; i++) { PyObject *item = PyList_GET_ITEM(obj, i); if (!PyTuple_Check(item) || (PyTuple_GET_SIZE(item) < 2)) { - PyErr_Format(PyExc_TypeError, - "Field elements must be 2- or 3-tuples, got '%R'", + PyErr_Format(PyExc_TypeError, + "Field elements must be 2- or 3-tuples, got '%R'", item); goto fail; } @@ -461,7 +461,7 @@ _convert_from_array_descr(PyObject *obj, int align) } else if (PyTuple_Check(name)) { if (PyTuple_GET_SIZE(name) != 2) { - PyErr_Format(PyExc_TypeError, + PyErr_Format(PyExc_TypeError, "If a tuple, the first element of a field tuple must have " "two elements, not %zd", PyTuple_GET_SIZE(name)); @@ -475,7 +475,7 @@ _convert_from_array_descr(PyObject *obj, int align) } } else { - PyErr_SetString(PyExc_TypeError, + PyErr_SetString(PyExc_TypeError, "First element of field tuple is " "neither a tuple nor str"); goto fail; @@ -3101,6 +3101,30 @@ arraydescr_newbyteorder(PyArray_Descr *self, PyObject *args) return (PyObject *)PyArray_DescrNewByteorder(self, endian); } +static PyObject * +arraydescr_class_getitem(PyObject *cls, PyObject *args) +{ + PyObject *generic_alias; + +#ifdef Py_GENERICALIASOBJECT_H + Py_ssize_t args_len; + + args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; + if (args_len != 1) { + return PyErr_Format(PyExc_TypeError, + "Too %s arguments for %s", + args_len > 1 ? "many" : "few", + ((PyTypeObject *)cls)->tp_name); + } + generic_alias = Py_GenericAlias(cls, args); +#else + PyErr_SetString(PyExc_TypeError, + "Type subscription requires python >= 3.9"); + generic_alias = NULL; +#endif + return generic_alias; +} + static PyMethodDef arraydescr_methods[] = { /* for pickling */ {"__reduce__", @@ -3112,6 +3136,10 @@ static PyMethodDef arraydescr_methods[] = { {"newbyteorder", (PyCFunction)arraydescr_newbyteorder, METH_VARARGS, NULL}, + /* for typing; requires python >= 3.9 */ + {"__class_getitem__", + (PyCFunction)arraydescr_class_getitem, + METH_CLASS | METH_O, NULL}, {NULL, NULL, 0, NULL} /* sentinel */ }; diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 2c10817fa..391e65f6a 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -2699,6 +2699,30 @@ array_complex(PyArrayObject *self, PyObject *NPY_UNUSED(args)) return c; } +static PyObject * +array_class_getitem(PyObject *cls, PyObject *args) +{ + PyObject *generic_alias; + +#ifdef Py_GENERICALIASOBJECT_H + Py_ssize_t args_len; + + args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; + if (args_len != 2) { + return PyErr_Format(PyExc_TypeError, + "Too %s arguments for %s", + args_len > 2 ? "many" : "few", + ((PyTypeObject *)cls)->tp_name); + } + generic_alias = Py_GenericAlias(cls, args); +#else + PyErr_SetString(PyExc_TypeError, + "Type subscription requires python >= 3.9"); + generic_alias = NULL; +#endif + return generic_alias; +} + NPY_NO_EXPORT PyMethodDef array_methods[] = { /* for subtypes */ @@ -2756,6 +2780,11 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { (PyCFunction) array_format, METH_VARARGS, NULL}, + /* for typing; requires python >= 3.9 */ + {"__class_getitem__", + (PyCFunction)array_class_getitem, + METH_CLASS | METH_O, NULL}, + /* Original and Extended methods added 2005 */ {"all", (PyCFunction)array_all, diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 4faa647ec..93cc9666e 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1805,6 +1805,59 @@ gentype_setflags(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), Py_RETURN_NONE; } +static PyObject * +numbertype_class_getitem_abc(PyObject *cls, PyObject *args) +{ + PyObject *generic_alias; + +#ifdef Py_GENERICALIASOBJECT_H + Py_ssize_t args_len; + int args_len_expected; + + /* complexfloating should take 2 parameters, all others take 1 */ + if (PyType_IsSubtype((PyTypeObject *)cls, + &PyComplexFloatingArrType_Type)) { + args_len_expected = 2; + } + else { + args_len_expected = 1; + } + + args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; + if (args_len != args_len_expected) { + return PyErr_Format(PyExc_TypeError, + "Too %s arguments for %s", + args_len > args_len_expected ? "many" : "few", + ((PyTypeObject *)cls)->tp_name); + } + generic_alias = Py_GenericAlias(cls, args); +#else + PyErr_SetString(PyExc_TypeError, + "Type subscription requires python >= 3.9"); + generic_alias = NULL; +#endif + return generic_alias; +} + +/* + * Use for concrete np.number subclasses, making them act as if they + * were subtyped from e.g. np.signedinteger[object], thus lacking any + * free subscription parameters. Requires python >= 3.9. + */ +static PyObject * +numbertype_class_getitem(PyObject *cls, PyObject *args) +{ +#ifdef Py_GENERICALIASOBJECT_H + PyErr_Format(PyExc_TypeError, + "There are no type variables left in %s", + ((PyTypeObject *)cls)->tp_name); +#else + PyErr_SetString(PyExc_TypeError, + "Type subscription requires python >= 3.9"); +#endif + return NULL; +} + /* * casting complex numbers (that don't inherit from Python complex) * to Python complex @@ -2188,6 +2241,14 @@ static PyGetSetDef inttype_getsets[] = { {NULL, NULL, NULL, NULL, NULL} }; +static PyMethodDef numbertype_methods[] = { + /* for typing; requires python >= 3.9 */ + {"__class_getitem__", + (PyCFunction)numbertype_class_getitem_abc, + METH_CLASS | METH_O, NULL}, + {NULL, NULL, 0, NULL} /* sentinel */ +}; + /**begin repeat * #name = cfloat,clongdouble# */ @@ -2195,6 +2256,10 @@ static PyMethodDef @name@type_methods[] = { {"__complex__", (PyCFunction)@name@_complex, METH_VARARGS | METH_KEYWORDS, NULL}, + /* for typing; requires python >= 3.9 */ + {"__class_getitem__", + (PyCFunction)numbertype_class_getitem, + METH_CLASS | METH_O, NULL}, {NULL, NULL, 0, NULL} }; /**end repeat**/ @@ -2232,6 +2297,23 @@ static PyMethodDef @name@type_methods[] = { {"is_integer", (PyCFunction)@name@_is_integer, METH_NOARGS, NULL}, + /* for typing; requires python >= 3.9 */ + {"__class_getitem__", + (PyCFunction)numbertype_class_getitem, + METH_CLASS | METH_O, NULL}, + {NULL, NULL, 0, NULL} +}; +/**end repeat**/ + +/**begin repeat + * #name = byte, short, int, long, longlong, ubyte, ushort, + * uint, ulong, ulonglong, timedelta, cdouble# + */ +static PyMethodDef @name@type_methods[] = { + /* for typing; requires python >= 3.9 */ + {"__class_getitem__", + (PyCFunction)numbertype_class_getitem, + METH_CLASS | METH_O, NULL}, {NULL, NULL, 0, NULL} }; /**end repeat**/ @@ -3951,6 +4033,8 @@ initialize_numeric_types(void) PyIntegerArrType_Type.tp_getset = inttype_getsets; + PyNumberArrType_Type.tp_methods = numbertype_methods; + /**begin repeat * #NAME= Number, Integer, SignedInteger, UnsignedInteger, Inexact, * Floating, ComplexFloating, Flexible, Character# @@ -4016,6 +4100,17 @@ initialize_numeric_types(void) /**end repeat**/ + /**begin repeat + * #name = byte, short, int, long, longlong, ubyte, ushort, + * uint, ulong, ulonglong, timedelta, cdouble# + * #Name = Byte, Short, Int, Long, LongLong, UByte, UShort, + * UInt, ULong, ULongLong, Timedelta, CDouble# + */ + + Py@Name@ArrType_Type.tp_methods = @name@type_methods; + + /**end repeat**/ + /* We won't be inheriting from Python Int type. */ PyIntArrType_Type.tp_hash = int_arrtype_hash; diff --git a/numpy/core/tests/test_arraymethod.py b/numpy/core/tests/test_arraymethod.py index b1bc79b80..49aa9f6df 100644 --- a/numpy/core/tests/test_arraymethod.py +++ b/numpy/core/tests/test_arraymethod.py @@ -3,6 +3,10 @@ This file tests the generic aspects of ArrayMethod. At the time of writing this is private API, but when added, public API may be added here. """ +import sys +import types +from typing import Any, Type + import pytest import numpy as np @@ -56,3 +60,35 @@ class TestSimpleStridedCall: # This is private API, which may be modified freely with pytest.raises(error): self.method._simple_strided_call(*args) + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") +class TestClassGetItem: + @pytest.mark.parametrize( + "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap] + ) + def test_class_getitem(self, cls: Type[np.ndarray]) -> None: + """Test `ndarray.__class_getitem__`.""" + alias = cls[Any, Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is cls + + @pytest.mark.parametrize("arg_len", range(4)) + def test_subscript_tuple(self, arg_len: int) -> None: + arg_tup = (Any,) * arg_len + if arg_len == 2: + assert np.ndarray[arg_tup] + else: + with pytest.raises(TypeError): + np.ndarray[arg_tup] + + def test_subscript_scalar(self) -> None: + with pytest.raises(TypeError): + np.ndarray[Any] + + +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") +def test_class_getitem_38() -> None: + match = "Type subscription requires python >= 3.9" + with pytest.raises(TypeError, match=match): + np.ndarray[Any, Any] diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 23269f01b..db4f275b5 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -4,6 +4,8 @@ import pytest import ctypes import gc import warnings +import types +from typing import Any import numpy as np from numpy.core._rational_tests import rational @@ -111,9 +113,9 @@ class TestBuiltin: @pytest.mark.parametrize("dtype", ['Bool', 'Bytes0', 'Complex32', 'Complex64', 'Datetime64', 'Float16', 'Float32', 'Float64', - 'Int8', 'Int16', 'Int32', 'Int64', + 'Int8', 'Int16', 'Int32', 'Int64', 'Object0', 'Str0', 'Timedelta64', - 'UInt8', 'UInt16', 'Uint32', 'UInt32', + 'UInt8', 'UInt16', 'Uint32', 'UInt32', 'Uint64', 'UInt64', 'Void0', "Float128", "Complex128"]) def test_numeric_style_types_are_invalid(self, dtype): @@ -1549,3 +1551,37 @@ class TestUserDType: # Tests that a dtype must have its type field set up to np.dtype # or in this case a builtin instance. create_custom_field_dtype(blueprint, mytype, 2) + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") +class TestClassGetItem: + def test_dtype(self) -> None: + alias = np.dtype[Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is np.dtype + + @pytest.mark.parametrize("code", np.typecodes["All"]) + def test_dtype_subclass(self, code: str) -> None: + cls = type(np.dtype(code)) + alias = cls[Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is cls + + @pytest.mark.parametrize("arg_len", range(4)) + def test_subscript_tuple(self, arg_len: int) -> None: + arg_tup = (Any,) * arg_len + if arg_len == 1: + assert np.dtype[arg_tup] + else: + with pytest.raises(TypeError): + np.dtype[arg_tup] + + def test_subscript_scalar(self) -> None: + assert np.dtype[Any] + + +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") +def test_class_getitem_38() -> None: + match = "Type subscription requires python >= 3.9" + with pytest.raises(TypeError, match=match): + np.dtype[Any] diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py index 94b2dd3c9..6077c8f75 100644 --- a/numpy/core/tests/test_scalar_methods.py +++ b/numpy/core/tests/test_scalar_methods.py @@ -1,8 +1,11 @@ """ Test the scalar constructors, which also do type-coercion """ +import sys import fractions import platform +import types +from typing import Any, Type import pytest import numpy as np @@ -128,3 +131,55 @@ class TestIsInteger: if value == 0: continue assert not value.is_integer() + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") +class TestClassGetItem: + @pytest.mark.parametrize("cls", [ + np.number, + np.integer, + np.inexact, + np.unsignedinteger, + np.signedinteger, + np.floating, + ]) + def test_abc(self, cls: Type[np.number]) -> None: + alias = cls[Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is cls + + def test_abc_complexfloating(self) -> None: + alias = np.complexfloating[Any, Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is np.complexfloating + + @pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character]) + def test_abc_non_numeric(self, cls: Type[np.generic]) -> None: + with pytest.raises(TypeError): + cls[Any] + + @pytest.mark.parametrize("code", np.typecodes["All"]) + def test_concrete(self, code: str) -> None: + cls = np.dtype(code).type + with pytest.raises(TypeError): + cls[Any] + + @pytest.mark.parametrize("arg_len", range(4)) + def test_subscript_tuple(self, arg_len: int) -> None: + arg_tup = (Any,) * arg_len + if arg_len == 1: + assert np.number[arg_tup] + else: + with pytest.raises(TypeError): + np.number[arg_tup] + + def test_subscript_scalar(self) -> None: + assert np.number[Any] + + +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") +@pytest.mark.parametrize("cls", [np.number, np.complexfloating, np.int64]) +def test_class_getitem_38(cls: Type[np.number]) -> None: + match = "Type subscription requires python >= 3.9" + with pytest.raises(TypeError, match=match): + cls[Any] diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py index 5ad5e580c..932f12dd0 100644 --- a/numpy/typing/_generic_alias.py +++ b/numpy/typing/_generic_alias.py @@ -205,12 +205,9 @@ else: ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True) -if TYPE_CHECKING: +if TYPE_CHECKING or sys.version_info >= (3, 9): _DType = np.dtype[ScalarType] NDArray = np.ndarray[Any, np.dtype[ScalarType]] -elif sys.version_info >= (3, 9): - _DType = types.GenericAlias(np.dtype, (ScalarType,)) - NDArray = types.GenericAlias(np.ndarray, (Any, _DType)) else: _DType = _GenericAlias(np.dtype, (ScalarType,)) NDArray = _GenericAlias(np.ndarray, (Any, _DType)) diff --git a/tools/refguide_check.py b/tools/refguide_check.py index a6bfc0fe4..21ba5a448 100644 --- a/tools/refguide_check.py +++ b/tools/refguide_check.py @@ -93,18 +93,27 @@ OTHER_MODULE_DOCS = { # these names are known to fail doctesting and we like to keep it that way # e.g. sometimes pseudocode is acceptable etc -DOCTEST_SKIPLIST = set([ +# +# Optionally, a subset of methods can be skipped by setting dict-values +# to a container of method-names +DOCTEST_SKIPDICT = { # cases where NumPy docstrings import things from SciPy: - 'numpy.lib.vectorize', - 'numpy.random.standard_gamma', - 'numpy.random.gamma', - 'numpy.random.vonmises', - 'numpy.random.power', - 'numpy.random.zipf', + 'numpy.lib.vectorize': None, + 'numpy.random.standard_gamma': None, + 'numpy.random.gamma': None, + 'numpy.random.vonmises': None, + 'numpy.random.power': None, + 'numpy.random.zipf': None, # remote / local file IO with DataSource is problematic in doctest: - 'numpy.lib.DataSource', - 'numpy.lib.Repository', -]) + 'numpy.lib.DataSource': None, + 'numpy.lib.Repository': None, +} +if sys.version_info < (3, 9): + DOCTEST_SKIPDICT.update({ + "numpy.core.ndarray": {"__class_getitem__"}, + "numpy.core.dtype": {"__class_getitem__"}, + "numpy.core.number": {"__class_getitem__"}, + }) # Skip non-numpy RST files, historical release notes # Any single-directory exact match will skip the directory and all subdirs. @@ -869,8 +878,12 @@ def check_doctests(module, verbose, ns=None, for name in get_all_dict(module)[0]: full_name = module.__name__ + '.' + name - if full_name in DOCTEST_SKIPLIST: - continue + if full_name in DOCTEST_SKIPDICT: + skip_methods = DOCTEST_SKIPDICT[full_name] + if skip_methods is None: + continue + else: + skip_methods = None try: obj = getattr(module, name) @@ -891,6 +904,10 @@ def check_doctests(module, verbose, ns=None, traceback.format_exc())) continue + if skip_methods is not None: + tests = [i for i in tests if + i.name.partition(".")[2] not in skip_methods] + success, output = _run_doctests(tests, full_name, verbose, doctest_warnings) @@ -971,7 +988,7 @@ def check_doctests_testfile(fname, verbose, ns=None, results = [] _, short_name = os.path.split(fname) - if short_name in DOCTEST_SKIPLIST: + if short_name in DOCTEST_SKIPDICT: return results full_name = fname |