summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-02-16 14:51:49 +0100
committerBas van Beek <b.f.van.beek@vu.nl>2021-02-25 14:06:37 +0100
commitf345d732a97b647de1fb26aeae533ca48e8229e9 (patch)
tree8fa40fd5427def493badf6495789b4b671db2d6b
parent8a18515f6b5eb97a184b92858741eb4a5aab613a (diff)
downloadnumpy-f345d732a97b647de1fb26aeae533ca48e8229e9.tar.gz
MAINT: Relax the type-constraints of `IndexExpression.__getitem__`
-rw-r--r--numpy/lib/index_tricks.pyi10
-rw-r--r--numpy/typing/tests/data/fail/index_tricks.py4
-rw-r--r--numpy/typing/tests/data/pass/index_tricks.py2
-rw-r--r--numpy/typing/tests/data/reveal/index_tricks.py2
4 files changed, 9 insertions, 9 deletions
diff --git a/numpy/lib/index_tricks.pyi b/numpy/lib/index_tricks.pyi
index e602f9907..3e5bc1adb 100644
--- a/numpy/lib/index_tricks.pyi
+++ b/numpy/lib/index_tricks.pyi
@@ -52,7 +52,7 @@ else:
_T = TypeVar("_T")
_DType = TypeVar("_DType", bound=dtype[Any])
_BoolType = TypeVar("_BoolType", Literal[True], Literal[False])
-_SliceOrTuple = TypeVar("_SliceOrTuple", bound=Union[slice, Tuple[slice, ...]])
+_TupType = TypeVar("_TupType", bound=Tuple[Any, ...])
_ArrayType = TypeVar("_ArrayType", bound=ndarray[Any, Any])
__all__: List[str]
@@ -163,11 +163,11 @@ class IndexExpression(Generic[_BoolType]):
maketuple: _BoolType
def __init__(self, maketuple: _BoolType) -> None: ...
@overload
- def __getitem__( # type: ignore[misc]
- self: IndexExpression[Literal[True]], item: slice
- ) -> Tuple[slice]: ...
+ def __getitem__(self, item: _TupType) -> _TupType: ... # type: ignore[misc]
@overload
- def __getitem__(self, item: _SliceOrTuple) -> _SliceOrTuple: ...
+ def __getitem__(self: IndexExpression[Literal[True]], item: _T) -> Tuple[_T]: ...
+ @overload
+ def __getitem__(self: IndexExpression[Literal[False]], item: _T) -> _T: ...
index_exp: IndexExpression[Literal[True]]
s_: IndexExpression[Literal[False]]
diff --git a/numpy/typing/tests/data/fail/index_tricks.py b/numpy/typing/tests/data/fail/index_tricks.py
index 706f135b2..cbc43fd54 100644
--- a/numpy/typing/tests/data/fail/index_tricks.py
+++ b/numpy/typing/tests/data/fail/index_tricks.py
@@ -10,9 +10,5 @@ np.mgrid[1] # E: Invalid index type
np.mgrid[...] # E: Invalid index type
np.ogrid[1] # E: Invalid index type
np.ogrid[...] # E: Invalid index type
-np.index_exp[1] # E: No overload variant
-np.index_exp[...] # E: No overload variant
-np.s_[1] # E: cannot be "int"
-np.s_[...] # E: cannot be "ellipsis"
np.fill_diagonal(AR_LIKE_f, 2) # E: incompatible type
np.diag_indices(1.0) # E: incompatible type
diff --git a/numpy/typing/tests/data/pass/index_tricks.py b/numpy/typing/tests/data/pass/index_tricks.py
index ce7f415f3..4c4c11959 100644
--- a/numpy/typing/tests/data/pass/index_tricks.py
+++ b/numpy/typing/tests/data/pass/index_tricks.py
@@ -46,9 +46,11 @@ np.ogrid[1:1:2, None:10]
np.index_exp[0:1]
np.index_exp[0:1, None:3]
+np.index_exp[0, 0:1, ..., [0, 1, 3]]
np.s_[0:1]
np.s_[0:1, None:3]
+np.s_[0, 0:1, ..., [0, 1, 3]]
np.ix_(AR_LIKE_b[0])
np.ix_(AR_LIKE_i[0], AR_LIKE_f[0])
diff --git a/numpy/typing/tests/data/reveal/index_tricks.py b/numpy/typing/tests/data/reveal/index_tricks.py
index dc061d314..ec2013025 100644
--- a/numpy/typing/tests/data/reveal/index_tricks.py
+++ b/numpy/typing/tests/data/reveal/index_tricks.py
@@ -45,9 +45,11 @@ reveal_type(np.ogrid[1:1:2, None:10]) # E: list[numpy.ndarray[Any, numpy.dtype[
reveal_type(np.index_exp[0:1]) # E: Tuple[builtins.slice]
reveal_type(np.index_exp[0:1, None:3]) # E: Tuple[builtins.slice, builtins.slice]
+reveal_type(np.index_exp[0, 0:1, ..., [0, 1, 3]]) # E: Tuple[Literal[0]?, builtins.slice, builtins.ellipsis, builtins.list[builtins.int]]
reveal_type(np.s_[0:1]) # E: builtins.slice
reveal_type(np.s_[0:1, None:3]) # E: Tuple[builtins.slice, builtins.slice]
+reveal_type(np.s_[0, 0:1, ..., [0, 1, 3]]) # E: Tuple[Literal[0]?, builtins.slice, builtins.ellipsis, builtins.list[builtins.int]]
reveal_type(np.ix_(AR_LIKE_b)) # E: tuple[numpy.ndarray[Any, numpy.dtype[numpy.bool_]]]
reveal_type(np.ix_(AR_LIKE_i, AR_LIKE_f)) # E: tuple[numpy.ndarray[Any, numpy.dtype[{double}]]]