diff options
author | Bas van Beek <bas.vanbeek@hotmail.com> | 2023-05-17 21:16:35 +0200 |
---|---|---|
committer | Bas van Beek <bas.vanbeek@hotmail.com> | 2023-05-17 21:16:35 +0200 |
commit | a4c249653ec9d063a67a6cde8123dca2defb8f8b (patch) | |
tree | 659a1e5d15cf070ca3954ba94108adddc114ab43 | |
parent | 610f85c2ee769853f1d2c291476b49face35c691 (diff) | |
download | numpy-a4c249653ec9d063a67a6cde8123dca2defb8f8b.tar.gz |
TYP: Update type annotations for the new linalg named tuples
-rw-r--r-- | numpy/linalg/linalg.py | 22 | ||||
-rw-r--r-- | numpy/linalg/linalg.pyi | 65 | ||||
-rw-r--r-- | numpy/typing/tests/data/reveal/linalg.pyi | 30 |
3 files changed, 66 insertions, 51 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index b939c9c95..0f06c8520 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -17,7 +17,7 @@ __all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv', import functools import operator import warnings -from typing import NamedTuple +from typing import NamedTuple, Any from .._utils import set_module from numpy.core import ( @@ -38,24 +38,24 @@ from numpy._typing import NDArray class EigResult(NamedTuple): eigenvalues: NDArray[Any] - eigenvectors: NDArray + eigenvectors: NDArray[Any] class EighResult(NamedTuple): - eigenvalues: NDArray - eigenvectors: NDArray + eigenvalues: NDArray[Any] + eigenvectors: NDArray[Any] class QRResult(NamedTuple): - Q: NDArray - R: NDArray + Q: NDArray[Any] + R: NDArray[Any] class SlogdetResult(NamedTuple): - sign: NDArray - logabsdet: NDArray + sign: NDArray[Any] + logabsdet: NDArray[Any] class SVDResult(NamedTuple): - U: NDArray - S: NDArray - Vh: NDArray + U: NDArray[Any] + S: NDArray[Any] + Vh: NDArray[Any] array_function_dispatch = functools.partial( overrides.array_function_dispatch, module='numpy.linalg') diff --git a/numpy/linalg/linalg.pyi b/numpy/linalg/linalg.pyi index 20cdb708b..c0b2f29b2 100644 --- a/numpy/linalg/linalg.pyi +++ b/numpy/linalg/linalg.pyi @@ -6,6 +6,8 @@ from typing import ( Any, SupportsIndex, SupportsInt, + NamedTuple, + Generic, ) from numpy import ( @@ -31,12 +33,37 @@ from numpy._typing import ( _T = TypeVar("_T") _ArrayType = TypeVar("_ArrayType", bound=NDArray[Any]) +_SCT = TypeVar("_SCT", bound=generic, covariant=True) +_SCT2 = TypeVar("_SCT2", bound=generic, covariant=True) _2Tuple = tuple[_T, _T] _ModeKind = L["reduced", "complete", "r", "raw"] __all__: list[str] +class EigResult(NamedTuple): + eigenvalues: NDArray[Any] + eigenvectors: NDArray[Any] + +class EighResult(NamedTuple): + eigenvalues: NDArray[Any] + eigenvectors: NDArray[Any] + +class QRResult(NamedTuple): + Q: NDArray[Any] + R: NDArray[Any] + +class SlogdetResult(NamedTuple): + # TODO: `sign` and `logabsdet` are scalars for input 2D arrays and + # a `(x.ndim - 2)`` dimensionl arrays otherwise + sign: Any + logabsdet: Any + +class SVDResult(NamedTuple): + U: NDArray[Any] + S: NDArray[Any] + Vh: NDArray[Any] + @overload def tensorsolve( a: _ArrayLikeInt_co, @@ -110,11 +137,11 @@ def cholesky(a: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... def cholesky(a: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ... @overload -def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[float64]]: ... +def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> QRResult: ... @overload -def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[floating[Any]]]: ... +def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> QRResult: ... @overload -def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[complexfloating[Any, Any]]]: ... +def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> QRResult: ... @overload def eigvals(a: _ArrayLikeInt_co) -> NDArray[float64] | NDArray[complex128]: ... @@ -129,27 +156,27 @@ def eigvalsh(a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[ def eigvalsh(a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[floating[Any]]: ... @overload -def eig(a: _ArrayLikeInt_co) -> _2Tuple[NDArray[float64]] | _2Tuple[NDArray[complex128]]: ... +def eig(a: _ArrayLikeInt_co) -> EigResult: ... @overload -def eig(a: _ArrayLikeFloat_co) -> _2Tuple[NDArray[floating[Any]]] | _2Tuple[NDArray[complexfloating[Any, Any]]]: ... +def eig(a: _ArrayLikeFloat_co) -> EigResult: ... @overload -def eig(a: _ArrayLikeComplex_co) -> _2Tuple[NDArray[complexfloating[Any, Any]]]: ... +def eig(a: _ArrayLikeComplex_co) -> EigResult: ... @overload def eigh( a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = ..., -) -> tuple[NDArray[float64], NDArray[float64]]: ... +) -> EighResult: ... @overload def eigh( a: _ArrayLikeFloat_co, UPLO: L["L", "U", "l", "u"] = ..., -) -> tuple[NDArray[floating[Any]], NDArray[floating[Any]]]: ... +) -> EighResult: ... @overload def eigh( a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = ..., -) -> tuple[NDArray[floating[Any]], NDArray[complexfloating[Any, Any]]]: ... +) -> EighResult: ... @overload def svd( @@ -157,33 +184,21 @@ def svd( full_matrices: bool = ..., compute_uv: L[True] = ..., hermitian: bool = ..., -) -> tuple[ - NDArray[float64], - NDArray[float64], - NDArray[float64], -]: ... +) -> SVDResult: ... @overload def svd( a: _ArrayLikeFloat_co, full_matrices: bool = ..., compute_uv: L[True] = ..., hermitian: bool = ..., -) -> tuple[ - NDArray[floating[Any]], - NDArray[floating[Any]], - NDArray[floating[Any]], -]: ... +) -> SVDResult: ... @overload def svd( a: _ArrayLikeComplex_co, full_matrices: bool = ..., compute_uv: L[True] = ..., hermitian: bool = ..., -) -> tuple[ - NDArray[complexfloating[Any, Any]], - NDArray[floating[Any]], - NDArray[complexfloating[Any, Any]], -]: ... +) -> SVDResult: ... @overload def svd( a: _ArrayLikeInt_co, @@ -231,7 +246,7 @@ def pinv( # TODO: Returns a 2-tuple of scalars for 2D arrays and # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise -def slogdet(a: _ArrayLikeComplex_co) -> _2Tuple[Any]: ... +def slogdet(a: _ArrayLikeComplex_co) -> SlogdetResult: ... # TODO: Returns a 2-tuple of scalars for 2D arrays and # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise diff --git a/numpy/typing/tests/data/reveal/linalg.pyi b/numpy/typing/tests/data/reveal/linalg.pyi index 19e13aed6..130351864 100644 --- a/numpy/typing/tests/data/reveal/linalg.pyi +++ b/numpy/typing/tests/data/reveal/linalg.pyi @@ -33,9 +33,9 @@ reveal_type(np.linalg.cholesky(AR_i8)) # E: ndarray[Any, dtype[{float64}]] reveal_type(np.linalg.cholesky(AR_f8)) # E: ndarray[Any, dtype[floating[Any]]] reveal_type(np.linalg.cholesky(AR_c16)) # E: ndarray[Any, dtype[complexfloating[Any, Any]]] -reveal_type(np.linalg.qr(AR_i8)) # E: Tuple[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]]] -reveal_type(np.linalg.qr(AR_f8)) # E: Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]]] -reveal_type(np.linalg.qr(AR_c16)) # E: Tuple[ndarray[Any, dtype[complexfloating[Any, Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]] +reveal_type(np.linalg.qr(AR_i8)) # E: QRResult +reveal_type(np.linalg.qr(AR_f8)) # E: QRResult +reveal_type(np.linalg.qr(AR_c16)) # E: QRResult reveal_type(np.linalg.eigvals(AR_i8)) # E: Union[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{complex128}]]] reveal_type(np.linalg.eigvals(AR_f8)) # E: Union[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]] @@ -45,17 +45,17 @@ reveal_type(np.linalg.eigvalsh(AR_i8)) # E: ndarray[Any, dtype[{float64}]] reveal_type(np.linalg.eigvalsh(AR_f8)) # E: ndarray[Any, dtype[floating[Any]]] reveal_type(np.linalg.eigvalsh(AR_c16)) # E: ndarray[Any, dtype[floating[Any]]] -reveal_type(np.linalg.eig(AR_i8)) # E: Union[Tuple[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]]], Tuple[ndarray[Any, dtype[{complex128}]], ndarray[Any, dtype[{complex128}]]]] -reveal_type(np.linalg.eig(AR_f8)) # E: Union[Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]]], Tuple[ndarray[Any, dtype[complexfloating[Any, Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]]] -reveal_type(np.linalg.eig(AR_c16)) # E: Tuple[ndarray[Any, dtype[complexfloating[Any, Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]] +reveal_type(np.linalg.eig(AR_i8)) # E: EigResult +reveal_type(np.linalg.eig(AR_f8)) # E: EigResult +reveal_type(np.linalg.eig(AR_c16)) # E: EigResult -reveal_type(np.linalg.eigh(AR_i8)) # E: Tuple[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]]] -reveal_type(np.linalg.eigh(AR_f8)) # E: Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]]] -reveal_type(np.linalg.eigh(AR_c16)) # E: Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]] +reveal_type(np.linalg.eigh(AR_i8)) # E: EighResult +reveal_type(np.linalg.eigh(AR_f8)) # E: EighResult +reveal_type(np.linalg.eigh(AR_c16)) # E: EighResult -reveal_type(np.linalg.svd(AR_i8)) # E: Tuple[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]]] -reveal_type(np.linalg.svd(AR_f8)) # E: Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]]] -reveal_type(np.linalg.svd(AR_c16)) # E: Tuple[ndarray[Any, dtype[complexfloating[Any, Any]]], ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]] +reveal_type(np.linalg.svd(AR_i8)) # E: SVDResult +reveal_type(np.linalg.svd(AR_f8)) # E: SVDResult +reveal_type(np.linalg.svd(AR_c16)) # E: SVDResult reveal_type(np.linalg.svd(AR_i8, compute_uv=False)) # E: ndarray[Any, dtype[{float64}]] reveal_type(np.linalg.svd(AR_f8, compute_uv=False)) # E: ndarray[Any, dtype[floating[Any]]] reveal_type(np.linalg.svd(AR_c16, compute_uv=False)) # E: ndarray[Any, dtype[floating[Any]]] @@ -72,9 +72,9 @@ reveal_type(np.linalg.pinv(AR_i8)) # E: ndarray[Any, dtype[{float64}]] reveal_type(np.linalg.pinv(AR_f8)) # E: ndarray[Any, dtype[floating[Any]]] reveal_type(np.linalg.pinv(AR_c16)) # E: ndarray[Any, dtype[complexfloating[Any, Any]]] -reveal_type(np.linalg.slogdet(AR_i8)) # E: Tuple[Any, Any] -reveal_type(np.linalg.slogdet(AR_f8)) # E: Tuple[Any, Any] -reveal_type(np.linalg.slogdet(AR_c16)) # E: Tuple[Any, Any] +reveal_type(np.linalg.slogdet(AR_i8)) # E: SlogdetResult +reveal_type(np.linalg.slogdet(AR_f8)) # E: SlogdetResult +reveal_type(np.linalg.slogdet(AR_c16)) # E: SlogdetResult reveal_type(np.linalg.det(AR_i8)) # E: Any reveal_type(np.linalg.det(AR_f8)) # E: Any |