summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Wilson <person142@users.noreply.github.com>2020-10-06 22:20:47 -0700
committerJosh Wilson <person142@users.noreply.github.com>2020-10-07 21:07:14 -0700
commit02688c220591250082d4ce109eb51421d8412099 (patch)
tree1cf0877dca790d2ab7218d2112c1ddfa4e10c951
parentfd0f3dd2723ed7effde52bf31a673c9128a0a28a (diff)
downloadnumpy-02688c220591250082d4ce109eb51421d8412099.tar.gz
MAINT: add more dtype __new__ overloads for missing scalar types
-rw-r--r--numpy/__init__.pyi282
-rw-r--r--numpy/typing/tests/data/reveal/dtype.py9
2 files changed, 270 insertions, 21 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 139f2a1bc..f4caaab7c 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -542,75 +542,307 @@ class dtype(Generic[_DTypeScalar]):
align: bool = ...,
copy: bool = ...,
) -> dtype[_DTypeScalar]: ...
- # Overloads for string aliases
+ # Overloads for string aliases, Python types, and some assorted
+ # other special cases. Order is sometimes important because of the
+ # subtype relationships
+ #
+ # bool < int < float < complex
+ #
+ # so we have to make sure the overloads for the narrowest type is
+ # first.
@overload
def __new__(
cls,
- dtype: Literal["float64", "f8", "<f8", ">f8", "float", "double", "float_", "d"],
+ dtype: Union[
+ Type[bool],
+ Literal[
+ "?",
+ "=?",
+ "<?",
+ ">?",
+ "bool",
+ "bool_",
+ ],
+ ],
align: bool = ...,
copy: bool = ...,
- ) -> dtype[float64]: ...
+ ) -> dtype[bool_]: ...
@overload
def __new__(
cls,
- dtype: Literal["float32", "f4", "<f4", ">f4", "single"],
+ dtype: Literal[
+ "uint8",
+ "u1",
+ "=u1",
+ "<u1",
+ ">u1",
+ "B",
+ "=B",
+ "<B",
+ ">B",
+ ],
align: bool = ...,
copy: bool = ...,
- ) -> dtype[float32]: ...
+ ) -> dtype[uint8]: ...
@overload
def __new__(
cls,
- dtype: Literal["int64", "i8", "<i8", ">i8"],
+ dtype: Literal[
+ "uint16",
+ "u2",
+ "=u2",
+ "<u2",
+ ">u2",
+ "h",
+ "=h",
+ "<h",
+ ">h",
+ ],
align: bool = ...,
copy: bool = ...,
- ) -> dtype[int64]: ...
+ ) -> dtype[uint16]: ...
@overload
def __new__(
cls,
- dtype: Literal["int32", "i4", "<i4", ">i4"],
+ dtype: Literal[
+ "uint32",
+ "u4",
+ "=u4",
+ "<u4",
+ ">u4",
+ "I",
+ "=I",
+ "<I",
+ ">I",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[uint32]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "uint64",
+ "u8",
+ "=u8",
+ "<u8",
+ ">u8",
+ "L",
+ "=L",
+ "<L",
+ ">L",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[uint64]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "int8",
+ "i1",
+ "=i1",
+ "<i1",
+ ">i1",
+ "b",
+ "=b",
+ "<b",
+ ">b",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[int8]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "int16",
+ "i2",
+ "=i2",
+ "<i2",
+ ">i2",
+ "h",
+ "=h",
+ "<h",
+ ">h",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[int16]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "int32",
+ "i4",
+ "=i4",
+ "<i4",
+ ">i4",
+ "i",
+ "=i",
+ "<i",
+ ">i",
+ ],
align: bool = ...,
copy: bool = ...,
) -> dtype[int32]: ...
- # "int" resolves to int_, which is system dependent, and as of now
- # untyped. Long-term we'll do something fancier here.
@overload
def __new__(
cls,
- dtype: Literal["int"],
+ dtype: Literal[
+ "int64",
+ "i8",
+ "=i8",
+ "<i8",
+ ">i8",
+ "l",
+ "=l",
+ "<l",
+ ">l",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[int64]: ...
+ # "int"/int resolve to int_, which is system dependent and as of
+ # now untyped. Long-term we'll do something fancier here.
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[Type[int], Literal["int"]],
align: bool = ...,
copy: bool = ...,
) -> dtype: ...
- # Overloads for Python types. Order is important here.
@overload
def __new__(
cls,
- dtype: Type[bool],
+ dtype: Literal[
+ "float16",
+ "f4",
+ "=f4",
+ "<f4",
+ ">f4",
+ "e",
+ "=e",
+ "<e",
+ ">e",
+ "half",
+ ],
align: bool = ...,
copy: bool = ...,
- ) -> dtype[bool_]: ...
- # See the notes for "int"
+ ) -> dtype[float16]: ...
@overload
def __new__(
cls,
- dtype: Type[int],
+ dtype: Literal[
+ "float32",
+ "f4",
+ "=f4",
+ "<f4",
+ ">f4",
+ "f",
+ "=f",
+ "<f",
+ ">f",
+ "single",
+ ],
align: bool = ...,
copy: bool = ...,
- ) -> dtype[Any]: ...
+ ) -> dtype[float32]: ...
@overload
def __new__(
cls,
- dtype: Type[float],
+ dtype: Union[
+ None,
+ Type[float],
+ Literal[
+ "float64",
+ "f8",
+ "=f8",
+ "<f8",
+ ">f8",
+ "d",
+ "<d",
+ ">d",
+ "float",
+ "double",
+ "float_",
+ ],
+ ],
align: bool = ...,
copy: bool = ...,
) -> dtype[float64]: ...
- # None is a special case
@overload
def __new__(
cls,
- dtype: None,
+ dtype: Literal[
+ "complex64",
+ "c8",
+ "=c8",
+ "<c8",
+ ">c8",
+ "F",
+ "=F",
+ "<F",
+ ">F",
+ ],
align: bool = ...,
copy: bool = ...,
- ) -> dtype[float64]: ...
+ ) -> dtype[complex128]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[
+ Type[complex],
+ Literal[
+ "complex128",
+ "c16",
+ "=c16",
+ "<c16",
+ ">c16",
+ "D",
+ "=D",
+ "<D",
+ ">D",
+ ],
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[complex128]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[
+ Type[bytes],
+ Literal[
+ "S",
+ "=S",
+ "<S",
+ ">S",
+ "bytes",
+ "bytes_",
+ ],
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[bytes_]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[
+ Type[str],
+ Literal[
+ "U",
+ "=U",
+ # <U and >U intentionally not included; they are not
+ # the same dtype and which one dtype("U") translates
+ # to is platform-dependent.
+ "str",
+ "str_",
+ ],
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[str_]: ...
# dtype of a dtype is the same dtype
@overload
def __new__(
@@ -627,6 +859,14 @@ class dtype(Generic[_DTypeScalar]):
align: bool = ...,
copy: bool = ...,
) -> dtype[Any]: ...
+ # Handle strings that can't be expressed as literals; i.e. s1, s2, ...
+ @overload
+ def __new__(
+ cls,
+ dtype: str,
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[Any]: ...
# Catchall overload
@overload
def __new__(
diff --git a/numpy/typing/tests/data/reveal/dtype.py b/numpy/typing/tests/data/reveal/dtype.py
index aca7e8a5e..e0802299e 100644
--- a/numpy/typing/tests/data/reveal/dtype.py
+++ b/numpy/typing/tests/data/reveal/dtype.py
@@ -8,11 +8,17 @@ reveal_type(np.dtype("float64")) # E: numpy.dtype[numpy.float64]
reveal_type(np.dtype("float32")) # E: numpy.dtype[numpy.float32]
reveal_type(np.dtype("int64")) # E: numpy.dtype[numpy.int64]
reveal_type(np.dtype("int32")) # E: numpy.dtype[numpy.int32]
+reveal_type(np.dtype("bool")) # E: numpy.dtype[numpy.bool_]
+reveal_type(np.dtype("bytes")) # E: numpy.dtype[numpy.bytes_]
+reveal_type(np.dtype("str")) # E: numpy.dtype[numpy.str_]
# Python types
+reveal_type(np.dtype(complex)) # E: numpy.dtype[numpy.complex128]
reveal_type(np.dtype(float)) # E: numpy.dtype[numpy.float64]
reveal_type(np.dtype(int)) # E: numpy.dtype
reveal_type(np.dtype(bool)) # E: numpy.dtype[numpy.bool_]
+reveal_type(np.dtype(str)) # E: numpy.dtype[numpy.str_]
+reveal_type(np.dtype(bytes)) # E: numpy.dtype[numpy.bytes_]
# Special case for None
reveal_type(np.dtype(None)) # E: numpy.dtype[numpy.float64]
@@ -20,5 +26,8 @@ reveal_type(np.dtype(None)) # E: numpy.dtype[numpy.float64]
# Dtypes of dtypes
reveal_type(np.dtype(np.dtype(np.float64))) # E: numpy.dtype[numpy.float64*]
+# Parameterized dtypes
+reveal_type(np.dtype("S8")) # E: numpy.dtype
+
# Void
reveal_type(np.dtype(("U", 10))) # E: numpy.dtype[numpy.void]