summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <bas.vanbeek@hotmail.com>2023-05-17 21:16:35 +0200
committerBas van Beek <bas.vanbeek@hotmail.com>2023-05-17 21:16:35 +0200
commita4c249653ec9d063a67a6cde8123dca2defb8f8b (patch)
tree659a1e5d15cf070ca3954ba94108adddc114ab43
parent610f85c2ee769853f1d2c291476b49face35c691 (diff)
downloadnumpy-a4c249653ec9d063a67a6cde8123dca2defb8f8b.tar.gz
TYP: Update type annotations for the new linalg named tuples
-rw-r--r--numpy/linalg/linalg.py22
-rw-r--r--numpy/linalg/linalg.pyi65
-rw-r--r--numpy/typing/tests/data/reveal/linalg.pyi30
3 files changed, 66 insertions, 51 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index b939c9c95..0f06c8520 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -17,7 +17,7 @@ __all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv',
import functools
import operator
import warnings
-from typing import NamedTuple
+from typing import NamedTuple, Any
from .._utils import set_module
from numpy.core import (
@@ -38,24 +38,24 @@ from numpy._typing import NDArray
class EigResult(NamedTuple):
eigenvalues: NDArray[Any]
- eigenvectors: NDArray
+ eigenvectors: NDArray[Any]
class EighResult(NamedTuple):
- eigenvalues: NDArray
- eigenvectors: NDArray
+ eigenvalues: NDArray[Any]
+ eigenvectors: NDArray[Any]
class QRResult(NamedTuple):
- Q: NDArray
- R: NDArray
+ Q: NDArray[Any]
+ R: NDArray[Any]
class SlogdetResult(NamedTuple):
- sign: NDArray
- logabsdet: NDArray
+ sign: NDArray[Any]
+ logabsdet: NDArray[Any]
class SVDResult(NamedTuple):
- U: NDArray
- S: NDArray
- Vh: NDArray
+ U: NDArray[Any]
+ S: NDArray[Any]
+ Vh: NDArray[Any]
array_function_dispatch = functools.partial(
overrides.array_function_dispatch, module='numpy.linalg')
diff --git a/numpy/linalg/linalg.pyi b/numpy/linalg/linalg.pyi
index 20cdb708b..c0b2f29b2 100644
--- a/numpy/linalg/linalg.pyi
+++ b/numpy/linalg/linalg.pyi
@@ -6,6 +6,8 @@ from typing import (
Any,
SupportsIndex,
SupportsInt,
+ NamedTuple,
+ Generic,
)
from numpy import (
@@ -31,12 +33,37 @@ from numpy._typing import (
_T = TypeVar("_T")
_ArrayType = TypeVar("_ArrayType", bound=NDArray[Any])
+_SCT = TypeVar("_SCT", bound=generic, covariant=True)
+_SCT2 = TypeVar("_SCT2", bound=generic, covariant=True)
_2Tuple = tuple[_T, _T]
_ModeKind = L["reduced", "complete", "r", "raw"]
__all__: list[str]
+class EigResult(NamedTuple):
+ eigenvalues: NDArray[Any]
+ eigenvectors: NDArray[Any]
+
+class EighResult(NamedTuple):
+ eigenvalues: NDArray[Any]
+ eigenvectors: NDArray[Any]
+
+class QRResult(NamedTuple):
+ Q: NDArray[Any]
+ R: NDArray[Any]
+
+class SlogdetResult(NamedTuple):
+ # TODO: `sign` and `logabsdet` are scalars for input 2D arrays and
+ # a `(x.ndim - 2)`` dimensionl arrays otherwise
+ sign: Any
+ logabsdet: Any
+
+class SVDResult(NamedTuple):
+ U: NDArray[Any]
+ S: NDArray[Any]
+ Vh: NDArray[Any]
+
@overload
def tensorsolve(
a: _ArrayLikeInt_co,
@@ -110,11 +137,11 @@ def cholesky(a: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ...
def cholesky(a: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
@overload
-def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[float64]]: ...
+def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> QRResult: ...
@overload
-def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[floating[Any]]]: ...
+def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> QRResult: ...
@overload
-def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[complexfloating[Any, Any]]]: ...
+def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> QRResult: ...
@overload
def eigvals(a: _ArrayLikeInt_co) -> NDArray[float64] | NDArray[complex128]: ...
@@ -129,27 +156,27 @@ def eigvalsh(a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[
def eigvalsh(a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[floating[Any]]: ...
@overload
-def eig(a: _ArrayLikeInt_co) -> _2Tuple[NDArray[float64]] | _2Tuple[NDArray[complex128]]: ...
+def eig(a: _ArrayLikeInt_co) -> EigResult: ...
@overload
-def eig(a: _ArrayLikeFloat_co) -> _2Tuple[NDArray[floating[Any]]] | _2Tuple[NDArray[complexfloating[Any, Any]]]: ...
+def eig(a: _ArrayLikeFloat_co) -> EigResult: ...
@overload
-def eig(a: _ArrayLikeComplex_co) -> _2Tuple[NDArray[complexfloating[Any, Any]]]: ...
+def eig(a: _ArrayLikeComplex_co) -> EigResult: ...
@overload
def eigh(
a: _ArrayLikeInt_co,
UPLO: L["L", "U", "l", "u"] = ...,
-) -> tuple[NDArray[float64], NDArray[float64]]: ...
+) -> EighResult: ...
@overload
def eigh(
a: _ArrayLikeFloat_co,
UPLO: L["L", "U", "l", "u"] = ...,
-) -> tuple[NDArray[floating[Any]], NDArray[floating[Any]]]: ...
+) -> EighResult: ...
@overload
def eigh(
a: _ArrayLikeComplex_co,
UPLO: L["L", "U", "l", "u"] = ...,
-) -> tuple[NDArray[floating[Any]], NDArray[complexfloating[Any, Any]]]: ...
+) -> EighResult: ...
@overload
def svd(
@@ -157,33 +184,21 @@ def svd(
full_matrices: bool = ...,
compute_uv: L[True] = ...,
hermitian: bool = ...,
-) -> tuple[
- NDArray[float64],
- NDArray[float64],
- NDArray[float64],
-]: ...
+) -> SVDResult: ...
@overload
def svd(
a: _ArrayLikeFloat_co,
full_matrices: bool = ...,
compute_uv: L[True] = ...,
hermitian: bool = ...,
-) -> tuple[
- NDArray[floating[Any]],
- NDArray[floating[Any]],
- NDArray[floating[Any]],
-]: ...
+) -> SVDResult: ...
@overload
def svd(
a: _ArrayLikeComplex_co,
full_matrices: bool = ...,
compute_uv: L[True] = ...,
hermitian: bool = ...,
-) -> tuple[
- NDArray[complexfloating[Any, Any]],
- NDArray[floating[Any]],
- NDArray[complexfloating[Any, Any]],
-]: ...
+) -> SVDResult: ...
@overload
def svd(
a: _ArrayLikeInt_co,
@@ -231,7 +246,7 @@ def pinv(
# TODO: Returns a 2-tuple of scalars for 2D arrays and
# a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
-def slogdet(a: _ArrayLikeComplex_co) -> _2Tuple[Any]: ...
+def slogdet(a: _ArrayLikeComplex_co) -> SlogdetResult: ...
# TODO: Returns a 2-tuple of scalars for 2D arrays and
# a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise
diff --git a/numpy/typing/tests/data/reveal/linalg.pyi b/numpy/typing/tests/data/reveal/linalg.pyi
index 19e13aed6..130351864 100644
--- a/numpy/typing/tests/data/reveal/linalg.pyi
+++ b/numpy/typing/tests/data/reveal/linalg.pyi
@@ -33,9 +33,9 @@ reveal_type(np.linalg.cholesky(AR_i8)) # E: ndarray[Any, dtype[{float64}]]
reveal_type(np.linalg.cholesky(AR_f8)) # E: ndarray[Any, dtype[floating[Any]]]
reveal_type(np.linalg.cholesky(AR_c16)) # E: ndarray[Any, dtype[complexfloating[Any, Any]]]
-reveal_type(np.linalg.qr(AR_i8)) # E: Tuple[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]]]
-reveal_type(np.linalg.qr(AR_f8)) # E: Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]]]
-reveal_type(np.linalg.qr(AR_c16)) # E: Tuple[ndarray[Any, dtype[complexfloating[Any, Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]]
+reveal_type(np.linalg.qr(AR_i8)) # E: QRResult
+reveal_type(np.linalg.qr(AR_f8)) # E: QRResult
+reveal_type(np.linalg.qr(AR_c16)) # E: QRResult
reveal_type(np.linalg.eigvals(AR_i8)) # E: Union[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{complex128}]]]
reveal_type(np.linalg.eigvals(AR_f8)) # E: Union[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]]
@@ -45,17 +45,17 @@ reveal_type(np.linalg.eigvalsh(AR_i8)) # E: ndarray[Any, dtype[{float64}]]
reveal_type(np.linalg.eigvalsh(AR_f8)) # E: ndarray[Any, dtype[floating[Any]]]
reveal_type(np.linalg.eigvalsh(AR_c16)) # E: ndarray[Any, dtype[floating[Any]]]
-reveal_type(np.linalg.eig(AR_i8)) # E: Union[Tuple[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]]], Tuple[ndarray[Any, dtype[{complex128}]], ndarray[Any, dtype[{complex128}]]]]
-reveal_type(np.linalg.eig(AR_f8)) # E: Union[Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]]], Tuple[ndarray[Any, dtype[complexfloating[Any, Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]]]
-reveal_type(np.linalg.eig(AR_c16)) # E: Tuple[ndarray[Any, dtype[complexfloating[Any, Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]]
+reveal_type(np.linalg.eig(AR_i8)) # E: EigResult
+reveal_type(np.linalg.eig(AR_f8)) # E: EigResult
+reveal_type(np.linalg.eig(AR_c16)) # E: EigResult
-reveal_type(np.linalg.eigh(AR_i8)) # E: Tuple[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]]]
-reveal_type(np.linalg.eigh(AR_f8)) # E: Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]]]
-reveal_type(np.linalg.eigh(AR_c16)) # E: Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]]
+reveal_type(np.linalg.eigh(AR_i8)) # E: EighResult
+reveal_type(np.linalg.eigh(AR_f8)) # E: EighResult
+reveal_type(np.linalg.eigh(AR_c16)) # E: EighResult
-reveal_type(np.linalg.svd(AR_i8)) # E: Tuple[ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]], ndarray[Any, dtype[{float64}]]]
-reveal_type(np.linalg.svd(AR_f8)) # E: Tuple[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[floating[Any]]]]
-reveal_type(np.linalg.svd(AR_c16)) # E: Tuple[ndarray[Any, dtype[complexfloating[Any, Any]]], ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]]
+reveal_type(np.linalg.svd(AR_i8)) # E: SVDResult
+reveal_type(np.linalg.svd(AR_f8)) # E: SVDResult
+reveal_type(np.linalg.svd(AR_c16)) # E: SVDResult
reveal_type(np.linalg.svd(AR_i8, compute_uv=False)) # E: ndarray[Any, dtype[{float64}]]
reveal_type(np.linalg.svd(AR_f8, compute_uv=False)) # E: ndarray[Any, dtype[floating[Any]]]
reveal_type(np.linalg.svd(AR_c16, compute_uv=False)) # E: ndarray[Any, dtype[floating[Any]]]
@@ -72,9 +72,9 @@ reveal_type(np.linalg.pinv(AR_i8)) # E: ndarray[Any, dtype[{float64}]]
reveal_type(np.linalg.pinv(AR_f8)) # E: ndarray[Any, dtype[floating[Any]]]
reveal_type(np.linalg.pinv(AR_c16)) # E: ndarray[Any, dtype[complexfloating[Any, Any]]]
-reveal_type(np.linalg.slogdet(AR_i8)) # E: Tuple[Any, Any]
-reveal_type(np.linalg.slogdet(AR_f8)) # E: Tuple[Any, Any]
-reveal_type(np.linalg.slogdet(AR_c16)) # E: Tuple[Any, Any]
+reveal_type(np.linalg.slogdet(AR_i8)) # E: SlogdetResult
+reveal_type(np.linalg.slogdet(AR_f8)) # E: SlogdetResult
+reveal_type(np.linalg.slogdet(AR_c16)) # E: SlogdetResult
reveal_type(np.linalg.det(AR_i8)) # E: Any
reveal_type(np.linalg.det(AR_f8)) # E: Any