summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2020-12-11 12:34:49 -0700
committerGitHub <noreply@github.com>2020-12-11 12:34:49 -0700
commitcfb6a4d88d88307d507a86d1d70cd7d84c611406 (patch)
tree10f337f6fb37ac09c590040e7fcea20c8beb8da1
parentba596df4cc23445a117c3785239a75a3cab6b2f2 (diff)
parent6139ed42af271d37234320e83a41000d93bdeae1 (diff)
downloadnumpy-cfb6a4d88d88307d507a86d1d70cd7d84c611406.tar.gz
Merge pull request #17981 from BvB93/flatiter
ENH: Add proper dtype-support to `np.flatiter`
-rw-r--r--numpy/__init__.pyi33
-rw-r--r--numpy/typing/_array_like.py21
-rw-r--r--numpy/typing/tests/data/fail/array_like.py4
-rw-r--r--numpy/typing/tests/data/pass/array_like.py8
-rw-r--r--numpy/typing/tests/data/pass/flatiter.py2
-rw-r--r--numpy/typing/tests/data/reveal/flatiter.py21
6 files changed, 56 insertions, 33 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index f2c414c0b..cf9b3e86a 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -955,24 +955,30 @@ _ArrayLikeInt = Union[
_FlatIterSelf = TypeVar("_FlatIterSelf", bound=flatiter)
-class flatiter(Generic[_ArraySelf]):
+class flatiter(Generic[_NdArraySubClass]):
@property
- def base(self) -> _ArraySelf: ...
+ def base(self) -> _NdArraySubClass: ...
@property
def coords(self) -> _Shape: ...
@property
def index(self) -> int: ...
- def copy(self) -> _ArraySelf: ...
+ def copy(self) -> _NdArraySubClass: ...
def __iter__(self: _FlatIterSelf) -> _FlatIterSelf: ...
- def __next__(self) -> generic: ...
+ def __next__(self: flatiter[ndarray[Any, dtype[_ScalarType]]]) -> _ScalarType: ...
def __len__(self) -> int: ...
@overload
- def __getitem__(self, key: Union[int, integer]) -> generic: ...
+ def __getitem__(
+ self: flatiter[ndarray[Any, dtype[_ScalarType]]],
+ key: Union[int, integer],
+ ) -> _ScalarType: ...
@overload
def __getitem__(
self, key: Union[_ArrayLikeInt, slice, ellipsis],
- ) -> _ArraySelf: ...
- def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ...
+ ) -> _NdArraySubClass: ...
+ @overload
+ def __array__(self: flatiter[ndarray[Any, _DType]], __dtype: None = ...) -> ndarray[Any, _DType]: ...
+ @overload
+ def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ...
_OrderKACF = Optional[Literal["K", "A", "C", "F"]]
_OrderACF = Optional[Literal["A", "C", "F"]]
@@ -1004,7 +1010,6 @@ class _ArrayOrScalarCommon:
def itemsize(self) -> int: ...
@property
def nbytes(self) -> int: ...
- def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ...
def __bool__(self) -> bool: ...
def __bytes__(self) -> bytes: ...
def __str__(self) -> str: ...
@@ -1468,6 +1473,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]):
strides: _ShapeLike = ...,
order: _OrderKACF = ...,
) -> _ArraySelf: ...
+ @overload
+ def __array__(self, __dtype: None = ...) -> ndarray[Any, _DType]: ...
+ @overload
+ def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ...
@property
def ctypes(self) -> _ctypes: ...
@property
@@ -1481,7 +1490,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType]):
def byteswap(self: _ArraySelf, inplace: bool = ...) -> _ArraySelf: ...
def fill(self, value: Any) -> None: ...
@property
- def flat(self: _ArraySelf) -> flatiter[_ArraySelf]: ...
+ def flat(self: _NdArraySubClass) -> flatiter[_NdArraySubClass]: ...
@overload
def item(self, *args: int) -> Any: ...
@overload
@@ -1646,6 +1655,10 @@ _NBit_co2 = TypeVar("_NBit_co2", covariant=True, bound=NBitBase)
class generic(_ArrayOrScalarCommon):
@abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
+ @overload
+ def __array__(self: _ScalarType, __dtype: None = ...) -> ndarray[Any, dtype[_ScalarType]]: ...
+ @overload
+ def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ...
@property
def base(self) -> None: ...
@property
@@ -1658,7 +1671,7 @@ class generic(_ArrayOrScalarCommon):
def strides(self) -> Tuple[()]: ...
def byteswap(self: _ScalarType, inplace: Literal[False] = ...) -> _ScalarType: ...
@property
- def flat(self) -> flatiter[ndarray]: ...
+ def flat(self: _ScalarType) -> flatiter[ndarray[Any, dtype[_ScalarType]]]: ...
def item(
self: _ScalarType,
__args: Union[Literal[0], Tuple[()], Tuple[Literal[0]]] = ...,
diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py
index a1a604239..63b67b33a 100644
--- a/numpy/typing/_array_like.py
+++ b/numpy/typing/_array_like.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import sys
-from typing import Any, overload, Sequence, TYPE_CHECKING, Union
+from typing import Any, overload, Sequence, TYPE_CHECKING, Union, TypeVar
-from numpy import ndarray
+from numpy import ndarray, dtype
from ._scalars import _ScalarLike
from ._dtype_like import DTypeLike
@@ -16,12 +18,15 @@ else:
else:
HAVE_PROTOCOL = True
+_DType = TypeVar("_DType", bound="dtype[Any]")
+
if TYPE_CHECKING or HAVE_PROTOCOL:
- class _SupportsArray(Protocol):
- @overload
- def __array__(self, __dtype: DTypeLike = ...) -> ndarray: ...
- @overload
- def __array__(self, dtype: DTypeLike = ...) -> ndarray: ...
+ # The `_SupportsArray` protocol only cares about the default dtype
+ # (i.e. `dtype=None`) of the to-be returned array.
+ # Concrete implementations of the protocol are responsible for adding
+ # any and all remaining overloads
+ class _SupportsArray(Protocol[_DType]):
+ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
else:
_SupportsArray = Any
@@ -36,5 +41,5 @@ ArrayLike = Union[
_ScalarLike,
Sequence[_ScalarLike],
Sequence[Sequence[Any]], # TODO: Wait for support for recursive types
- _SupportsArray,
+ "_SupportsArray[Any]",
]
diff --git a/numpy/typing/tests/data/fail/array_like.py b/numpy/typing/tests/data/fail/array_like.py
index a97e72dc7..3bbd29061 100644
--- a/numpy/typing/tests/data/fail/array_like.py
+++ b/numpy/typing/tests/data/fail/array_like.py
@@ -11,6 +11,6 @@ x2: ArrayLike = A() # E: Incompatible types in assignment
x3: ArrayLike = {1: "foo", 2: "bar"} # E: Incompatible types in assignment
scalar = np.int64(1)
-scalar.__array__(dtype=np.float64) # E: Unexpected keyword argument
+scalar.__array__(dtype=np.float64) # E: No overload variant
array = np.array([1])
-array.__array__(dtype=np.float64) # E: Unexpected keyword argument
+array.__array__(dtype=np.float64) # E: No overload variant
diff --git a/numpy/typing/tests/data/pass/array_like.py b/numpy/typing/tests/data/pass/array_like.py
index f85724267..563fc08c7 100644
--- a/numpy/typing/tests/data/pass/array_like.py
+++ b/numpy/typing/tests/data/pass/array_like.py
@@ -25,13 +25,13 @@ class A:
x13: ArrayLike = A()
scalar: _SupportsArray = np.int64(1)
-scalar.__array__(np.float64)
+scalar.__array__(None)
array: _SupportsArray = np.array(1)
-array.__array__(np.float64)
+array.__array__(None)
a: _SupportsArray = A()
-a.__array__(np.int64)
-a.__array__(dtype=np.int64)
+a.__array__(None)
+a.__array__(dtype=None)
# Escape hatch for when you mean to make something like an object
# array.
diff --git a/numpy/typing/tests/data/pass/flatiter.py b/numpy/typing/tests/data/pass/flatiter.py
index c0219eb2b..4fdf25299 100644
--- a/numpy/typing/tests/data/pass/flatiter.py
+++ b/numpy/typing/tests/data/pass/flatiter.py
@@ -12,3 +12,5 @@ a[0]
a[[0, 1, 2]]
a[...]
a[:]
+a.__array__()
+a.__array__(np.float64)
diff --git a/numpy/typing/tests/data/reveal/flatiter.py b/numpy/typing/tests/data/reveal/flatiter.py
index 56cdc7a0e..221101ebb 100644
--- a/numpy/typing/tests/data/reveal/flatiter.py
+++ b/numpy/typing/tests/data/reveal/flatiter.py
@@ -1,14 +1,17 @@
+from typing import Any
import numpy as np
-a: "np.flatiter[np.ndarray]"
+a: np.flatiter[np.ndarray[Any, np.dtype[np.str_]]]
-reveal_type(a.base) # E: numpy.ndarray*
-reveal_type(a.copy()) # E: numpy.ndarray*
+reveal_type(a.base) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
+reveal_type(a.copy()) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
reveal_type(a.coords) # E: tuple[builtins.int]
reveal_type(a.index) # E: int
-reveal_type(iter(a)) # E: Iterator[numpy.generic*]
-reveal_type(next(a)) # E: numpy.generic
-reveal_type(a[0]) # E: numpy.generic
-reveal_type(a[[0, 1, 2]]) # E: numpy.ndarray*
-reveal_type(a[...]) # E: numpy.ndarray*
-reveal_type(a[:]) # E: numpy.ndarray*
+reveal_type(iter(a)) # E: Iterator[numpy.str_]
+reveal_type(next(a)) # E: numpy.str_
+reveal_type(a[0]) # E: numpy.str_
+reveal_type(a[[0, 1, 2]]) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
+reveal_type(a[...]) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
+reveal_type(a[:]) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
+reveal_type(a.__array__()) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
+reveal_type(a.__array__(np.float64)) # E: numpy.ndarray[Any, numpy.dtype[Any]]