diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2020-12-11 12:34:49 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-11 12:34:49 -0700 |
commit | cfb6a4d88d88307d507a86d1d70cd7d84c611406 (patch) | |
tree | 10f337f6fb37ac09c590040e7fcea20c8beb8da1 | |
parent | ba596df4cc23445a117c3785239a75a3cab6b2f2 (diff) | |
parent | 6139ed42af271d37234320e83a41000d93bdeae1 (diff) | |
download | numpy-cfb6a4d88d88307d507a86d1d70cd7d84c611406.tar.gz |
Merge pull request #17981 from BvB93/flatiter
ENH: Add proper dtype-support to `np.flatiter`
-rw-r--r-- | numpy/__init__.pyi | 33 | ||||
-rw-r--r-- | numpy/typing/_array_like.py | 21 | ||||
-rw-r--r-- | numpy/typing/tests/data/fail/array_like.py | 4 | ||||
-rw-r--r-- | numpy/typing/tests/data/pass/array_like.py | 8 | ||||
-rw-r--r-- | numpy/typing/tests/data/pass/flatiter.py | 2 | ||||
-rw-r--r-- | numpy/typing/tests/data/reveal/flatiter.py | 21 |
6 files changed, 56 insertions, 33 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index f2c414c0b..cf9b3e86a 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -955,24 +955,30 @@ _ArrayLikeInt = Union[ _FlatIterSelf = TypeVar("_FlatIterSelf", bound=flatiter) -class flatiter(Generic[_ArraySelf]): +class flatiter(Generic[_NdArraySubClass]): @property - def base(self) -> _ArraySelf: ... + def base(self) -> _NdArraySubClass: ... @property def coords(self) -> _Shape: ... @property def index(self) -> int: ... - def copy(self) -> _ArraySelf: ... + def copy(self) -> _NdArraySubClass: ... def __iter__(self: _FlatIterSelf) -> _FlatIterSelf: ... - def __next__(self) -> generic: ... + def __next__(self: flatiter[ndarray[Any, dtype[_ScalarType]]]) -> _ScalarType: ... def __len__(self) -> int: ... @overload - def __getitem__(self, key: Union[int, integer]) -> generic: ... + def __getitem__( + self: flatiter[ndarray[Any, dtype[_ScalarType]]], + key: Union[int, integer], + ) -> _ScalarType: ... @overload def __getitem__( self, key: Union[_ArrayLikeInt, slice, ellipsis], - ) -> _ArraySelf: ... - def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ... + ) -> _NdArraySubClass: ... + @overload + def __array__(self: flatiter[ndarray[Any, _DType]], __dtype: None = ...) -> ndarray[Any, _DType]: ... + @overload + def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ... _OrderKACF = Optional[Literal["K", "A", "C", "F"]] _OrderACF = Optional[Literal["A", "C", "F"]] @@ -1004,7 +1010,6 @@ class _ArrayOrScalarCommon: def itemsize(self) -> int: ... @property def nbytes(self) -> int: ... - def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ... def __bool__(self) -> bool: ... def __bytes__(self) -> bytes: ... def __str__(self) -> str: ... @@ -1468,6 +1473,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]): strides: _ShapeLike = ..., order: _OrderKACF = ..., ) -> _ArraySelf: ... + @overload + def __array__(self, __dtype: None = ...) -> ndarray[Any, _DType]: ... + @overload + def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ... @property def ctypes(self) -> _ctypes: ... @property @@ -1481,7 +1490,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]): def byteswap(self: _ArraySelf, inplace: bool = ...) -> _ArraySelf: ... def fill(self, value: Any) -> None: ... @property - def flat(self: _ArraySelf) -> flatiter[_ArraySelf]: ... + def flat(self: _NdArraySubClass) -> flatiter[_NdArraySubClass]: ... @overload def item(self, *args: int) -> Any: ... @overload @@ -1646,6 +1655,10 @@ _NBit_co2 = TypeVar("_NBit_co2", covariant=True, bound=NBitBase) class generic(_ArrayOrScalarCommon): @abstractmethod def __init__(self, *args: Any, **kwargs: Any) -> None: ... + @overload + def __array__(self: _ScalarType, __dtype: None = ...) -> ndarray[Any, dtype[_ScalarType]]: ... + @overload + def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ... @property def base(self) -> None: ... @property @@ -1658,7 +1671,7 @@ class generic(_ArrayOrScalarCommon): def strides(self) -> Tuple[()]: ... def byteswap(self: _ScalarType, inplace: Literal[False] = ...) -> _ScalarType: ... @property - def flat(self) -> flatiter[ndarray]: ... + def flat(self: _ScalarType) -> flatiter[ndarray[Any, dtype[_ScalarType]]]: ... def item( self: _ScalarType, __args: Union[Literal[0], Tuple[()], Tuple[Literal[0]]] = ..., diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py index a1a604239..63b67b33a 100644 --- a/numpy/typing/_array_like.py +++ b/numpy/typing/_array_like.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import sys -from typing import Any, overload, Sequence, TYPE_CHECKING, Union +from typing import Any, overload, Sequence, TYPE_CHECKING, Union, TypeVar -from numpy import ndarray +from numpy import ndarray, dtype from ._scalars import _ScalarLike from ._dtype_like import DTypeLike @@ -16,12 +18,15 @@ else: else: HAVE_PROTOCOL = True +_DType = TypeVar("_DType", bound="dtype[Any]") + if TYPE_CHECKING or HAVE_PROTOCOL: - class _SupportsArray(Protocol): - @overload - def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ... - @overload - def __array__(self, dtype: DTypeLike = ...) -> ndarray: ... + # The `_SupportsArray` protocol only cares about the default dtype + # (i.e. `dtype=None`) of the to-be returned array. + # Concrete implementations of the protocol are responsible for adding + # any and all remaining overloads + class _SupportsArray(Protocol[_DType]): + def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ... else: _SupportsArray = Any @@ -36,5 +41,5 @@ ArrayLike = Union[ _ScalarLike, Sequence[_ScalarLike], Sequence[Sequence[Any]], # TODO: Wait for support for recursive types - _SupportsArray, + "_SupportsArray[Any]", ] diff --git a/numpy/typing/tests/data/fail/array_like.py b/numpy/typing/tests/data/fail/array_like.py index a97e72dc7..3bbd29061 100644 --- a/numpy/typing/tests/data/fail/array_like.py +++ b/numpy/typing/tests/data/fail/array_like.py @@ -11,6 +11,6 @@ x2: ArrayLike = A() # E: Incompatible types in assignment x3: ArrayLike = {1: "foo", 2: "bar"} # E: Incompatible types in assignment scalar = np.int64(1) -scalar.__array__(dtype=np.float64) # E: Unexpected keyword argument +scalar.__array__(dtype=np.float64) # E: No overload variant array = np.array([1]) -array.__array__(dtype=np.float64) # E: Unexpected keyword argument +array.__array__(dtype=np.float64) # E: No overload variant diff --git a/numpy/typing/tests/data/pass/array_like.py b/numpy/typing/tests/data/pass/array_like.py index f85724267..563fc08c7 100644 --- a/numpy/typing/tests/data/pass/array_like.py +++ b/numpy/typing/tests/data/pass/array_like.py @@ -25,13 +25,13 @@ class A: x13: ArrayLike = A() scalar: _SupportsArray = np.int64(1) -scalar.__array__(np.float64) +scalar.__array__(None) array: _SupportsArray = np.array(1) -array.__array__(np.float64) +array.__array__(None) a: _SupportsArray = A() -a.__array__(np.int64) -a.__array__(dtype=np.int64) +a.__array__(None) +a.__array__(dtype=None) # Escape hatch for when you mean to make something like an object # array. diff --git a/numpy/typing/tests/data/pass/flatiter.py b/numpy/typing/tests/data/pass/flatiter.py index c0219eb2b..4fdf25299 100644 --- a/numpy/typing/tests/data/pass/flatiter.py +++ b/numpy/typing/tests/data/pass/flatiter.py @@ -12,3 +12,5 @@ a[0] a[[0, 1, 2]] a[...] a[:] +a.__array__() +a.__array__(np.float64) diff --git a/numpy/typing/tests/data/reveal/flatiter.py b/numpy/typing/tests/data/reveal/flatiter.py index 56cdc7a0e..221101ebb 100644 --- a/numpy/typing/tests/data/reveal/flatiter.py +++ b/numpy/typing/tests/data/reveal/flatiter.py @@ -1,14 +1,17 @@ +from typing import Any import numpy as np -a: "np.flatiter[np.ndarray]" +a: np.flatiter[np.ndarray[Any, np.dtype[np.str_]]] -reveal_type(a.base) # E: numpy.ndarray* -reveal_type(a.copy()) # E: numpy.ndarray* +reveal_type(a.base) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]] +reveal_type(a.copy()) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]] reveal_type(a.coords) # E: tuple[builtins.int] reveal_type(a.index) # E: int -reveal_type(iter(a)) # E: Iterator[numpy.generic*] -reveal_type(next(a)) # E: numpy.generic -reveal_type(a[0]) # E: numpy.generic -reveal_type(a[[0, 1, 2]]) # E: numpy.ndarray* -reveal_type(a[...]) # E: numpy.ndarray* -reveal_type(a[:]) # E: numpy.ndarray* +reveal_type(iter(a)) # E: Iterator[numpy.str_] +reveal_type(next(a)) # E: numpy.str_ +reveal_type(a[0]) # E: numpy.str_ +reveal_type(a[[0, 1, 2]]) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]] +reveal_type(a[...]) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]] +reveal_type(a[:]) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]] +reveal_type(a.__array__()) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]] +reveal_type(a.__array__(np.float64)) # E: numpy.ndarray[Any, numpy.dtype[Any]] |