diff options
| author | Charles Harris <charlesr.harris@gmail.com> | 2021-02-27 17:39:53 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-02-27 17:39:53 -0700 |
| commit | 1932e4bdc460975d8b35ae8351e037d26e5ee6f8 (patch) | |
| tree | 1e627ec579eef776185d4743570daa4c27a04def | |
| parent | cefab37bfba72ae062a58be3c7bdbd59c2c4d309 (diff) | |
| parent | f345d732a97b647de1fb26aeae533ca48e8229e9 (diff) | |
| download | numpy-1932e4bdc460975d8b35ae8351e037d26e5ee6f8.tar.gz | |
Merge pull request #18397 from BvB93/index-tricks
ENH: Add annotations for `np.lib.index_tricks`
| -rw-r--r-- | numpy/__init__.pyi | 57 | ||||
| -rw-r--r-- | numpy/lib/__init__.pyi | 34 | ||||
| -rw-r--r-- | numpy/lib/index_tricks.pyi | 179 | ||||
| -rw-r--r-- | numpy/typing/__init__.py | 2 | ||||
| -rw-r--r-- | numpy/typing/_array_like.py | 5 | ||||
| -rw-r--r-- | numpy/typing/tests/data/fail/index_tricks.py | 14 | ||||
| -rw-r--r-- | numpy/typing/tests/data/pass/index_tricks.py | 64 | ||||
| -rw-r--r-- | numpy/typing/tests/data/reveal/index_tricks.py | 63 |
8 files changed, 389 insertions, 29 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 148a63583..dba4176f2 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -338,6 +338,21 @@ from numpy.core.shape_base import ( vstack as vstack, ) +from numpy.lib.index_tricks import ( + ravel_multi_index as ravel_multi_index, + unravel_index as unravel_index, + mgrid as mgrid, + ogrid as ogrid, + r_ as r_, + c_ as c_, + s_ as s_, + index_exp as index_exp, + ix_ as ix_, + fill_diagonal as fill_diagonal, + diag_indices as diag_indices, + diag_indices_from as diag_indices_from, +) + from numpy.lib.ufunclike import ( fix as fix, isposinf as isposinf, @@ -375,7 +390,6 @@ busday_count: Any busday_offset: Any busdaycalendar: Any byte_bounds: Any -c_: Any can_cast: Any cast: Any chararray: Any @@ -395,8 +409,6 @@ delete: Any deprecate: Any deprecate_with_doc: Any diag: Any -diag_indices: Any -diag_indices_from: Any diagflat: Any diff: Any digitize: Any @@ -417,7 +429,6 @@ def eye( *, like: Optional[ArrayLike] = ... ) -> ndarray[Any, Any]: ... -fill_diagonal: Any finfo: Any flip: Any fliplr: Any @@ -444,7 +455,6 @@ i0: Any iinfo: Any imag: Any in1d: Any -index_exp: Any info: Any inner: Any insert: Any @@ -457,7 +467,6 @@ isin: Any isreal: Any isrealobj: Any iterable: Any -ix_: Any kaiser: Any kron: Any lexsort: Any @@ -474,7 +483,6 @@ may_share_memory: Any median: Any memmap: Any meshgrid: Any -mgrid: Any min: Any min_scalar_type: Any mintypecode: Any @@ -496,14 +504,11 @@ nanstd: Any nansum: Any nanvar: Any nbytes: Any -ndenumerate: Any ndfromtxt: Any -ndindex: Any nditer: Any nested_iters: Any newaxis: Any numarray: Any -ogrid: Any packbits: Any pad: Any percentile: Any @@ -524,8 +529,6 @@ promote_types: Any put_along_axis: Any putmask: Any quantile: Any -r_: Any -ravel_multi_index: Any real: Any real_if_close: Any recarray: Any @@ -538,7 +541,6 @@ rot90: Any round: Any round_: Any row_stack: Any -s_: Any save: Any savetxt: Any savez: Any @@ -570,7 +572,6 @@ typename: Any union1d: Any unique: Any unpackbits: Any -unravel_index: Any unwrap: Any vander: Any vdot: Any @@ -2899,3 +2900,31 @@ class errstate(Generic[_CallType], ContextDecorator): __exc_value: Optional[BaseException], __traceback: Optional[TracebackType], ) -> None: ... + +class ndenumerate(Generic[_ScalarType]): + iter: flatiter[_ArrayND[_ScalarType]] + @overload + def __new__( + cls, arr: _NestedSequence[_SupportsArray[dtype[_ScalarType]]], + ) -> ndenumerate[_ScalarType]: ... + @overload + def __new__(cls, arr: _NestedSequence[str]) -> ndenumerate[str_]: ... + @overload + def __new__(cls, arr: _NestedSequence[bytes]) -> ndenumerate[bytes_]: ... + @overload + def __new__(cls, arr: _NestedSequence[bool]) -> ndenumerate[bool_]: ... + @overload + def __new__(cls, arr: _NestedSequence[int]) -> ndenumerate[int_]: ... + @overload + def __new__(cls, arr: _NestedSequence[float]) -> ndenumerate[float_]: ... + @overload + def __new__(cls, arr: _NestedSequence[complex]) -> ndenumerate[complex_]: ... + @overload + def __new__(cls, arr: _RecursiveSequence) -> ndenumerate[Any]: ... + def __next__(self: ndenumerate[_ScalarType]) -> Tuple[_Shape, _ScalarType]: ... + def __iter__(self: _T) -> _T: ... + +class ndindex: + def __init__(self, *shape: SupportsIndex) -> None: ... + def __iter__(self: _T) -> _T: ... + def __next__(self) -> _Shape: ... diff --git a/numpy/lib/__init__.pyi b/numpy/lib/__init__.pyi index 4468d27e9..c7fab6943 100644 --- a/numpy/lib/__init__.pyi +++ b/numpy/lib/__init__.pyi @@ -1,5 +1,25 @@ from typing import Any, List +from numpy import ( + ndenumerate as ndenumerate, + ndindex as ndindex, +) + +from numpy.lib.index_tricks import ( + ravel_multi_index as ravel_multi_index, + unravel_index as unravel_index, + mgrid as mgrid, + ogrid as ogrid, + r_ as r_, + c_ as c_, + s_ as s_, + index_exp as index_exp, + ix_ as ix_, + fill_diagonal as fill_diagonal, + diag_indices as diag_indices, + diag_indices_from as diag_indices_from, +) + from numpy.lib.ufunclike import ( fix as fix, isposinf as isposinf, @@ -25,20 +45,6 @@ asfarray: Any mintypecode: Any asscalar: Any common_type: Any -ravel_multi_index: Any -unravel_index: Any -mgrid: Any -ogrid: Any -r_: Any -c_: Any -s_: Any -index_exp: Any -ix_: Any -ndenumerate: Any -ndindex: Any -fill_diagonal: Any -diag_indices: Any -diag_indices_from: Any select: Any piecewise: Any trim_zeros: Any diff --git a/numpy/lib/index_tricks.pyi b/numpy/lib/index_tricks.pyi new file mode 100644 index 000000000..3e5bc1adb --- /dev/null +++ b/numpy/lib/index_tricks.pyi @@ -0,0 +1,179 @@ +import sys +from typing import ( + Any, + Tuple, + TypeVar, + Generic, + overload, + List, + Union, + Sequence, +) + +from numpy import ( + # Circumvent a naming conflict with `AxisConcatenator.matrix` + matrix as _Matrix, + ndenumerate as ndenumerate, + ndindex as ndindex, + ndarray, + dtype, + str_, + bytes_, + bool_, + int_, + float_, + complex_, + intp, + _OrderCF, + _ModeKind, +) +from numpy.typing import ( + # Arrays + ArrayLike, + _NestedSequence, + _RecursiveSequence, + _ArrayND, + _ArrayOrScalar, + _ArrayLikeInt, + + # DTypes + DTypeLike, + _SupportsDType, + + # Shapes + _ShapeLike, +) + +if sys.version_info >= (3, 8): + from typing import Literal, SupportsIndex +else: + from typing_extensions import Literal, SupportsIndex + +_T = TypeVar("_T") +_DType = TypeVar("_DType", bound=dtype[Any]) +_BoolType = TypeVar("_BoolType", Literal[True], Literal[False]) +_TupType = TypeVar("_TupType", bound=Tuple[Any, ...]) +_ArrayType = TypeVar("_ArrayType", bound=ndarray[Any, Any]) + +__all__: List[str] + +def unravel_index( + indices: _ArrayLikeInt, + shape: _ShapeLike, + order: _OrderCF = ... +) -> Tuple[_ArrayOrScalar[intp], ...]: ... + +def ravel_multi_index( + multi_index: ArrayLike, + dims: _ShapeLike, + mode: Union[_ModeKind, Tuple[_ModeKind, ...]] = ..., + order: _OrderCF = ... +) -> _ArrayOrScalar[intp]: ... + +@overload +def ix_(*args: _NestedSequence[_SupportsDType[_DType]]) -> Tuple[ndarray[Any, _DType], ...]: ... +@overload +def ix_(*args: _NestedSequence[str]) -> Tuple[_ArrayND[str_], ...]: ... +@overload +def ix_(*args: _NestedSequence[bytes]) -> Tuple[_ArrayND[bytes_], ...]: ... +@overload +def ix_(*args: _NestedSequence[bool]) -> Tuple[_ArrayND[bool_], ...]: ... +@overload +def ix_(*args: _NestedSequence[int]) -> Tuple[_ArrayND[int_], ...]: ... +@overload +def ix_(*args: _NestedSequence[float]) -> Tuple[_ArrayND[float_], ...]: ... +@overload +def ix_(*args: _NestedSequence[complex]) -> Tuple[_ArrayND[complex_], ...]: ... +@overload +def ix_(*args: _RecursiveSequence) -> Tuple[_ArrayND[Any], ...]: ... + +class nd_grid(Generic[_BoolType]): + sparse: _BoolType + def __init__(self, sparse: _BoolType = ...) -> None: ... + @overload + def __getitem__( + self: nd_grid[Literal[False]], + key: Union[slice, Sequence[slice]], + ) -> _ArrayND[Any]: ... + @overload + def __getitem__( + self: nd_grid[Literal[True]], + key: Union[slice, Sequence[slice]], + ) -> List[_ArrayND[Any]]: ... + +class MGridClass(nd_grid[Literal[False]]): + def __init__(self) -> None: ... + +mgrid: MGridClass + +class OGridClass(nd_grid[Literal[True]]): + def __init__(self) -> None: ... + +ogrid: OGridClass + +class AxisConcatenator: + axis: int + matrix: bool + ndmin: int + trans1d: int + def __init__( + self, + axis: int = ..., + matrix: bool = ..., + ndmin: int = ..., + trans1d: int = ..., + ) -> None: ... + @staticmethod + @overload + def concatenate( # type: ignore[misc] + *a: ArrayLike, axis: SupportsIndex = ..., out: None = ... + ) -> _ArrayND[Any]: ... + @staticmethod + @overload + def concatenate( + *a: ArrayLike, axis: SupportsIndex = ..., out: _ArrayType = ... + ) -> _ArrayType: ... + @staticmethod + def makemat( + data: ArrayLike, dtype: DTypeLike = ..., copy: bool = ... + ) -> _Matrix: ... + + # TODO: Sort out this `__getitem__` method + def __getitem__(self, key: Any) -> Any: ... + +class RClass(AxisConcatenator): + axis: Literal[0] + matrix: Literal[False] + ndmin: Literal[1] + trans1d: Literal[-1] + def __init__(self) -> None: ... + +r_: RClass + +class CClass(AxisConcatenator): + axis: Literal[-1] + matrix: Literal[False] + ndmin: Literal[2] + trans1d: Literal[0] + def __init__(self) -> None: ... + +c_: CClass + +class IndexExpression(Generic[_BoolType]): + maketuple: _BoolType + def __init__(self, maketuple: _BoolType) -> None: ... + @overload + def __getitem__(self, item: _TupType) -> _TupType: ... # type: ignore[misc] + @overload + def __getitem__(self: IndexExpression[Literal[True]], item: _T) -> Tuple[_T]: ... + @overload + def __getitem__(self: IndexExpression[Literal[False]], item: _T) -> _T: ... + +index_exp: IndexExpression[Literal[True]] +s_: IndexExpression[Literal[False]] + +def fill_diagonal(a: ndarray[Any, Any], val: Any, wrap: bool = ...) -> None: ... +def diag_indices(n: int, ndim: int = ...) -> Tuple[_ArrayND[int_], ...]: ... +def diag_indices_from(arr: ArrayLike) -> Tuple[_ArrayND[int_], ...]: ... + +# NOTE: see `numpy/__init__.pyi` for `ndenumerate` and `ndindex` diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py index 61d780b85..d71ec0719 100644 --- a/numpy/typing/__init__.py +++ b/numpy/typing/__init__.py @@ -327,6 +327,7 @@ from ._array_like import ( _SupportsArray, _ArrayND, _ArrayOrScalar, + _ArrayLikeInt, _ArrayLikeBool_co, _ArrayLikeUInt_co, _ArrayLikeInt_co, @@ -339,7 +340,6 @@ from ._array_like import ( _ArrayLikeVoid_co, _ArrayLikeStr_co, _ArrayLikeBytes_co, - ) if __doc__ is not None: diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py index 133f38800..ef6c061d1 100644 --- a/numpy/typing/_array_like.py +++ b/numpy/typing/_array_like.py @@ -124,6 +124,11 @@ _ArrayLikeBytes_co = _ArrayLike[ bytes, ] +_ArrayLikeInt = _ArrayLike[ + "dtype[integer[Any]]", + int, +] + if TYPE_CHECKING: _ArrayND = ndarray[Any, dtype[_ScalarType]] _ArrayOrScalar = Union[_ScalarType, _ArrayND[_ScalarType]] diff --git a/numpy/typing/tests/data/fail/index_tricks.py b/numpy/typing/tests/data/fail/index_tricks.py new file mode 100644 index 000000000..cbc43fd54 --- /dev/null +++ b/numpy/typing/tests/data/fail/index_tricks.py @@ -0,0 +1,14 @@ +from typing import List +import numpy as np + +AR_LIKE_i: List[int] +AR_LIKE_f: List[float] + +np.unravel_index(AR_LIKE_f, (1, 2, 3)) # E: incompatible type +np.ravel_multi_index(AR_LIKE_i, (1, 2, 3), mode="bob") # E: incompatible type +np.mgrid[1] # E: Invalid index type +np.mgrid[...] # E: Invalid index type +np.ogrid[1] # E: Invalid index type +np.ogrid[...] # E: Invalid index type +np.fill_diagonal(AR_LIKE_f, 2) # E: incompatible type +np.diag_indices(1.0) # E: incompatible type diff --git a/numpy/typing/tests/data/pass/index_tricks.py b/numpy/typing/tests/data/pass/index_tricks.py new file mode 100644 index 000000000..4c4c11959 --- /dev/null +++ b/numpy/typing/tests/data/pass/index_tricks.py @@ -0,0 +1,64 @@ +from __future__ import annotations +from typing import Any +import numpy as np + +AR_LIKE_b = [[True, True], [True, True]] +AR_LIKE_i = [[1, 2], [3, 4]] +AR_LIKE_f = [[1.0, 2.0], [3.0, 4.0]] +AR_LIKE_U = [["1", "2"], ["3", "4"]] + +AR_i8: np.ndarray[Any, np.dtype[np.int64]] = np.array(AR_LIKE_i, dtype=np.int64) + +np.ndenumerate(AR_i8) +np.ndenumerate(AR_LIKE_f) +np.ndenumerate(AR_LIKE_U) + +np.ndenumerate(AR_i8).iter +np.ndenumerate(AR_LIKE_f).iter +np.ndenumerate(AR_LIKE_U).iter + +next(np.ndenumerate(AR_i8)) +next(np.ndenumerate(AR_LIKE_f)) +next(np.ndenumerate(AR_LIKE_U)) + +iter(np.ndenumerate(AR_i8)) +iter(np.ndenumerate(AR_LIKE_f)) +iter(np.ndenumerate(AR_LIKE_U)) + +iter(np.ndindex(1, 2, 3)) +next(np.ndindex(1, 2, 3)) + +np.unravel_index([22, 41, 37], (7, 6)) +np.unravel_index([31, 41, 13], (7, 6), order='F') +np.unravel_index(1621, (6, 7, 8, 9)) + +np.ravel_multi_index(AR_LIKE_i, (7, 6)) +np.ravel_multi_index(AR_LIKE_i, (7, 6), order='F') +np.ravel_multi_index(AR_LIKE_i, (4, 6), mode='clip') +np.ravel_multi_index(AR_LIKE_i, (4, 4), mode=('clip', 'wrap')) +np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9)) + +np.mgrid[1:1:2] +np.mgrid[1:1:2, None:10] + +np.ogrid[1:1:2] +np.ogrid[1:1:2, None:10] + +np.index_exp[0:1] +np.index_exp[0:1, None:3] +np.index_exp[0, 0:1, ..., [0, 1, 3]] + +np.s_[0:1] +np.s_[0:1, None:3] +np.s_[0, 0:1, ..., [0, 1, 3]] + +np.ix_(AR_LIKE_b[0]) +np.ix_(AR_LIKE_i[0], AR_LIKE_f[0]) +np.ix_(AR_i8[0]) + +np.fill_diagonal(AR_i8, 5) + +np.diag_indices(4) +np.diag_indices(2, 3) + +np.diag_indices_from(AR_i8) diff --git a/numpy/typing/tests/data/reveal/index_tricks.py b/numpy/typing/tests/data/reveal/index_tricks.py new file mode 100644 index 000000000..ec2013025 --- /dev/null +++ b/numpy/typing/tests/data/reveal/index_tricks.py @@ -0,0 +1,63 @@ +from typing import Any, List +import numpy as np + +AR_LIKE_b: List[bool] +AR_LIKE_i: List[int] +AR_LIKE_f: List[float] +AR_LIKE_U: List[str] + +AR_i8: np.ndarray[Any, np.dtype[np.int64]] + +reveal_type(np.ndenumerate(AR_i8)) # E: numpy.ndenumerate[{int64}] +reveal_type(np.ndenumerate(AR_LIKE_f)) # E: numpy.ndenumerate[{double}] +reveal_type(np.ndenumerate(AR_LIKE_U)) # E: numpy.ndenumerate[numpy.str_] + +reveal_type(np.ndenumerate(AR_i8).iter) # E: numpy.flatiter[numpy.ndarray[Any, numpy.dtype[{int64}]]] +reveal_type(np.ndenumerate(AR_LIKE_f).iter) # E: numpy.flatiter[numpy.ndarray[Any, numpy.dtype[{double}]]] +reveal_type(np.ndenumerate(AR_LIKE_U).iter) # E: numpy.flatiter[numpy.ndarray[Any, numpy.dtype[numpy.str_]]] + +reveal_type(next(np.ndenumerate(AR_i8))) # E: Tuple[builtins.tuple[builtins.int], {int64}] +reveal_type(next(np.ndenumerate(AR_LIKE_f))) # E: Tuple[builtins.tuple[builtins.int], {double}] +reveal_type(next(np.ndenumerate(AR_LIKE_U))) # E: Tuple[builtins.tuple[builtins.int], numpy.str_] + +reveal_type(iter(np.ndenumerate(AR_i8))) # E: Iterator[Tuple[builtins.tuple[builtins.int], {int64}]] +reveal_type(iter(np.ndenumerate(AR_LIKE_f))) # E: Iterator[Tuple[builtins.tuple[builtins.int], {double}]] +reveal_type(iter(np.ndenumerate(AR_LIKE_U))) # E: Iterator[Tuple[builtins.tuple[builtins.int], numpy.str_]] + +reveal_type(iter(np.ndindex(1, 2, 3))) # E: Iterator[builtins.tuple[builtins.int]] +reveal_type(next(np.ndindex(1, 2, 3))) # E: builtins.tuple[builtins.int] + +reveal_type(np.unravel_index([22, 41, 37], (7, 6))) # E: tuple[Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]] +reveal_type(np.unravel_index([31, 41, 13], (7, 6), order="F")) # E: tuple[Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]] +reveal_type(np.unravel_index(1621, (6, 7, 8, 9))) # E: tuple[Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]]] + +reveal_type(np.ravel_multi_index(AR_LIKE_i, (7, 6))) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]] +reveal_type(np.ravel_multi_index(AR_LIKE_i, (7, 6), order="F")) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]] +reveal_type(np.ravel_multi_index(AR_LIKE_i, (4, 6), mode="clip")) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]] +reveal_type(np.ravel_multi_index(AR_LIKE_i, (4, 4), mode=("clip", "wrap"))) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]] +reveal_type(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9))) # E: Union[{intp}, numpy.ndarray[Any, numpy.dtype[{intp}]]] + +reveal_type(np.mgrid[1:1:2]) # E: numpy.ndarray[Any, numpy.dtype[Any]] +reveal_type(np.mgrid[1:1:2, None:10]) # E: numpy.ndarray[Any, numpy.dtype[Any]] + +reveal_type(np.ogrid[1:1:2]) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]] +reveal_type(np.ogrid[1:1:2, None:10]) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]] + +reveal_type(np.index_exp[0:1]) # E: Tuple[builtins.slice] +reveal_type(np.index_exp[0:1, None:3]) # E: Tuple[builtins.slice, builtins.slice] +reveal_type(np.index_exp[0, 0:1, ..., [0, 1, 3]]) # E: Tuple[Literal[0]?, builtins.slice, builtins.ellipsis, builtins.list[builtins.int]] + +reveal_type(np.s_[0:1]) # E: builtins.slice +reveal_type(np.s_[0:1, None:3]) # E: Tuple[builtins.slice, builtins.slice] +reveal_type(np.s_[0, 0:1, ..., [0, 1, 3]]) # E: Tuple[Literal[0]?, builtins.slice, builtins.ellipsis, builtins.list[builtins.int]] + +reveal_type(np.ix_(AR_LIKE_b)) # E: tuple[numpy.ndarray[Any, numpy.dtype[numpy.bool_]]] +reveal_type(np.ix_(AR_LIKE_i, AR_LIKE_f)) # E: tuple[numpy.ndarray[Any, numpy.dtype[{double}]]] +reveal_type(np.ix_(AR_i8)) # E: tuple[numpy.ndarray[Any, numpy.dtype[{int64}]]] + +reveal_type(np.fill_diagonal(AR_i8, 5)) # E: None + +reveal_type(np.diag_indices(4)) # E: tuple[numpy.ndarray[Any, numpy.dtype[{int_}]]] +reveal_type(np.diag_indices(2, 3)) # E: tuple[numpy.ndarray[Any, numpy.dtype[{int_}]]] + +reveal_type(np.diag_indices_from(AR_i8)) # E: tuple[numpy.ndarray[Any, numpy.dtype[{int_}]]] |
