summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-02-12 07:32:29 -0700
committerGitHub <noreply@github.com>2021-02-12 07:32:29 -0700
commita5dc2b5b917fc50575e10bbe139a0c78e43a1c1c (patch)
tree2096648daf6f93f67e9491531237f7144d51a472
parente1ec9e868bf03ec4284b9a5fb291418983b3a4fb (diff)
parentbea70e9443a2eaf2134aa450b019d702d711eab6 (diff)
downloadnumpy-a5dc2b5b917fc50575e10bbe139a0c78e43a1c1c.tar.gz
Merge pull request #18390 from BvB93/einsum
ENH: Add annotations for `np.core.einsumfunc`
-rw-r--r--numpy/__init__.pyi7
-rw-r--r--numpy/core/einsumfunc.pyi138
-rw-r--r--numpy/typing/__init__.py1
-rw-r--r--numpy/typing/_dtype_like.py8
-rw-r--r--numpy/typing/tests/data/fail/einsumfunc.py15
-rw-r--r--numpy/typing/tests/data/pass/einsumfunc.py37
-rw-r--r--numpy/typing/tests/data/reveal/einsumfunc.py32
-rw-r--r--numpy/typing/tests/test_typing.py2
8 files changed, 237 insertions, 3 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 0e9deef61..1c52c7285 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -281,6 +281,11 @@ from numpy.core.arrayprint import (
printoptions as printoptions,
)
+from numpy.core.einsumfunc import (
+ einsum as einsum,
+ einsum_path as einsum_path,
+)
+
from numpy.core.numeric import (
zeros_like as zeros_like,
ones as ones,
@@ -401,8 +406,6 @@ dot: Any
dsplit: Any
dstack: Any
ediff1d: Any
-einsum: Any
-einsum_path: Any
expand_dims: Any
extract: Any
eye: Any
diff --git a/numpy/core/einsumfunc.pyi b/numpy/core/einsumfunc.pyi
new file mode 100644
index 000000000..b33aff29f
--- /dev/null
+++ b/numpy/core/einsumfunc.pyi
@@ -0,0 +1,138 @@
+import sys
+from typing import List, TypeVar, Optional, Any, overload, Union, Tuple, Sequence
+
+from numpy import (
+ ndarray,
+ dtype,
+ bool_,
+ unsignedinteger,
+ signedinteger,
+ floating,
+ complexfloating,
+ number,
+ _OrderKACF,
+)
+from numpy.typing import (
+ _ArrayOrScalar,
+ _ArrayLikeBool_co,
+ _ArrayLikeUInt_co,
+ _ArrayLikeInt_co,
+ _ArrayLikeFloat_co,
+ _ArrayLikeComplex_co,
+ _DTypeLikeBool,
+ _DTypeLikeUInt,
+ _DTypeLikeInt,
+ _DTypeLikeFloat,
+ _DTypeLikeComplex,
+ _DTypeLikeComplex_co,
+)
+
+if sys.version_info >= (3, 8):
+ from typing import Literal
+else:
+ from typing_extensions import Literal
+
+_ArrayType = TypeVar(
+ "_ArrayType",
+ bound=ndarray[Any, dtype[Union[bool_, number[Any]]]],
+)
+
+_OptimizeKind = Union[
+ None, bool, Literal["greedy", "optimal"], Sequence[Any]
+]
+_CastingSafe = Literal["no", "equiv", "safe", "same_kind"]
+_CastingUnsafe = Literal["unsafe"]
+
+__all__: List[str]
+
+# TODO: Properly handle the `casting`-based combinatorics
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeBool_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeBool] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[bool_]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeUInt_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeUInt] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[unsignedinteger[Any]]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeInt_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeInt] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[signedinteger[Any]]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeFloat_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeFloat] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[floating[Any]]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeComplex_co,
+ out: None = ...,
+ dtype: Optional[_DTypeLikeComplex] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[complexfloating[Any, Any]]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: Any,
+ casting: _CastingUnsafe,
+ dtype: Optional[_DTypeLikeComplex_co] = ...,
+ out: None = ...,
+ order: _OrderKACF = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayOrScalar[Any]: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: _ArrayLikeComplex_co,
+ out: _ArrayType,
+ dtype: Optional[_DTypeLikeComplex_co] = ...,
+ order: _OrderKACF = ...,
+ casting: _CastingSafe = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayType: ...
+@overload
+def einsum(
+ __subscripts: str,
+ *operands: Any,
+ out: _ArrayType,
+ casting: _CastingUnsafe,
+ dtype: Optional[_DTypeLikeComplex_co] = ...,
+ order: _OrderKACF = ...,
+ optimize: _OptimizeKind = ...,
+) -> _ArrayType: ...
+
+# NOTE: `einsum_call` is a hidden kwarg unavailable for public use.
+# It is therefore excluded from the signatures below.
+# NOTE: In practice the list consists of a `str` (first element)
+# and a variable number of integer tuples.
+def einsum_path(
+ __subscripts: str,
+ *operands: _ArrayLikeComplex_co,
+ optimize: _OptimizeKind = ...,
+) -> Tuple[List[Any], str]: ...
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index 8f5df483b..61d780b85 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -317,6 +317,7 @@ from ._dtype_like import (
_DTypeLikeVoid,
_DTypeLikeStr,
_DTypeLikeBytes,
+ _DTypeLikeComplex_co,
)
from ._array_like import (
ArrayLike as ArrayLike,
diff --git a/numpy/typing/_dtype_like.py b/numpy/typing/_dtype_like.py
index f86b4a67c..a41e2f358 100644
--- a/numpy/typing/_dtype_like.py
+++ b/numpy/typing/_dtype_like.py
@@ -228,3 +228,11 @@ _DTypeLikeObject = Union[
"_SupportsDType[np.dtype[np.object_]]",
_ObjectCodes,
]
+
+_DTypeLikeComplex_co = Union[
+ _DTypeLikeBool,
+ _DTypeLikeUInt,
+ _DTypeLikeInt,
+ _DTypeLikeFloat,
+ _DTypeLikeComplex,
+]
diff --git a/numpy/typing/tests/data/fail/einsumfunc.py b/numpy/typing/tests/data/fail/einsumfunc.py
new file mode 100644
index 000000000..33722f861
--- /dev/null
+++ b/numpy/typing/tests/data/fail/einsumfunc.py
@@ -0,0 +1,15 @@
+from typing import List, Any
+import numpy as np
+
+AR_i: np.ndarray[Any, np.dtype[np.int64]]
+AR_f: np.ndarray[Any, np.dtype[np.float64]]
+AR_m: np.ndarray[Any, np.dtype[np.timedelta64]]
+AR_O: np.ndarray[Any, np.dtype[np.object_]]
+AR_U: np.ndarray[Any, np.dtype[np.str_]]
+
+np.einsum("i,i->i", AR_i, AR_m) # E: incompatible type
+np.einsum("i,i->i", AR_O, AR_O) # E: incompatible type
+np.einsum("i,i->i", AR_f, AR_f, dtype=np.int32) # E: incompatible type
+np.einsum("i,i->i", AR_i, AR_i, dtype=np.timedelta64, casting="unsafe") # E: No overload variant
+np.einsum("i,i->i", AR_i, AR_i, out=AR_U) # E: Value of type variable "_ArrayType" of "einsum" cannot be
+np.einsum("i,i->i", AR_i, AR_i, out=AR_U, casting="unsafe") # E: No overload variant
diff --git a/numpy/typing/tests/data/pass/einsumfunc.py b/numpy/typing/tests/data/pass/einsumfunc.py
new file mode 100644
index 000000000..914eed4cc
--- /dev/null
+++ b/numpy/typing/tests/data/pass/einsumfunc.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from typing import List, Any
+
+import pytest
+import numpy as np
+
+AR_LIKE_b = [True, True, True]
+AR_LIKE_u = [np.uint32(1), np.uint32(2), np.uint32(3)]
+AR_LIKE_i = [1, 2, 3]
+AR_LIKE_f = [1.0, 2.0, 3.0]
+AR_LIKE_c = [1j, 2j, 3j]
+AR_LIKE_U = ["1", "2", "3"]
+
+OUT_c: np.ndarray[Any, np.dtype[np.complex128]] = np.empty(3, dtype=np.complex128)
+
+np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_b)
+np.einsum("i,i->i", AR_LIKE_u, AR_LIKE_u)
+np.einsum("i,i->i", AR_LIKE_i, AR_LIKE_i)
+np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f)
+np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c)
+np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_i)
+np.einsum("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c)
+
+np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, dtype="c16")
+np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe")
+np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, out=OUT_c)
+with pytest.raises(np.ComplexWarning):
+ np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=float, casting="unsafe", out=OUT_c)
+
+np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_b)
+np.einsum_path("i,i->i", AR_LIKE_u, AR_LIKE_u)
+np.einsum_path("i,i->i", AR_LIKE_i, AR_LIKE_i)
+np.einsum_path("i,i->i", AR_LIKE_f, AR_LIKE_f)
+np.einsum_path("i,i->i", AR_LIKE_c, AR_LIKE_c)
+np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_i)
+np.einsum_path("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c)
diff --git a/numpy/typing/tests/data/reveal/einsumfunc.py b/numpy/typing/tests/data/reveal/einsumfunc.py
new file mode 100644
index 000000000..18c192b0b
--- /dev/null
+++ b/numpy/typing/tests/data/reveal/einsumfunc.py
@@ -0,0 +1,32 @@
+from typing import List, Any
+import numpy as np
+
+AR_LIKE_b: List[bool]
+AR_LIKE_u: List[np.uint32]
+AR_LIKE_i: List[int]
+AR_LIKE_f: List[float]
+AR_LIKE_c: List[complex]
+AR_LIKE_U: List[str]
+
+OUT_f: np.ndarray[Any, np.dtype[np.float64]]
+
+reveal_type(np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_b)) # E: Union[numpy.bool_, numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
+reveal_type(np.einsum("i,i->i", AR_LIKE_u, AR_LIKE_u)) # E: Union[numpy.unsignedinteger[Any], numpy.ndarray[Any, numpy.dtype[numpy.unsignedinteger[Any]]]
+reveal_type(np.einsum("i,i->i", AR_LIKE_i, AR_LIKE_i)) # E: Union[numpy.signedinteger[Any], numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
+reveal_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f)) # E: Union[numpy.floating[Any], numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]]
+reveal_type(np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c)) # E: Union[numpy.complexfloating[Any, Any], numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]]
+reveal_type(np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_i)) # E: Union[numpy.signedinteger[Any], numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
+reveal_type(np.einsum("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c)) # E: Union[numpy.complexfloating[Any, Any], numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]]
+
+reveal_type(np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c, out=OUT_f)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]
+reveal_type(np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe", out=OUT_f)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]
+reveal_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, dtype="c16")) # E: Union[numpy.complexfloating[Any, Any], numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]]
+reveal_type(np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe")) # E: Any
+
+reveal_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_b)) # E: Tuple[builtins.list[Any], builtins.str]
+reveal_type(np.einsum_path("i,i->i", AR_LIKE_u, AR_LIKE_u)) # E: Tuple[builtins.list[Any], builtins.str]
+reveal_type(np.einsum_path("i,i->i", AR_LIKE_i, AR_LIKE_i)) # E: Tuple[builtins.list[Any], builtins.str]
+reveal_type(np.einsum_path("i,i->i", AR_LIKE_f, AR_LIKE_f)) # E: Tuple[builtins.list[Any], builtins.str]
+reveal_type(np.einsum_path("i,i->i", AR_LIKE_c, AR_LIKE_c)) # E: Tuple[builtins.list[Any], builtins.str]
+reveal_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_i)) # E: Tuple[builtins.list[Any], builtins.str]
+reveal_type(np.einsum_path("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c)) # E: Tuple[builtins.list[Any], builtins.str]
diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py
index eb7e0b09e..e80282420 100644
--- a/numpy/typing/tests/test_typing.py
+++ b/numpy/typing/tests/test_typing.py
@@ -91,7 +91,7 @@ def test_success(path):
# Alias `OUTPUT_MYPY` so that it appears in the local namespace
output_mypy = OUTPUT_MYPY
if path in output_mypy:
- raise AssertionError("\n".join(v for v in output_mypy[path].values()))
+ raise AssertionError("\n".join(v for v in output_mypy[path]))
@pytest.mark.slow