summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@gmail.com>2021-01-29 12:07:01 +0000
committerGitHub <noreply@github.com>2021-01-29 12:07:01 +0000
commit9772bea95f3d079c2b59107ddd9731602b5edde0 (patch)
tree4738e7abad35b274431445bad3cf9a9a9181737c
parentb969f8c9a2178bec235d6c47db3e7cd68c5d39c1 (diff)
parentd788788f8233d96e647c64bd4070305146b9ff92 (diff)
downloadnumpy-9772bea95f3d079c2b59107ddd9731602b5edde0.tar.gz
Merge pull request #18236 from BvB93/dtype-like
ENH: Add aliases for commonly used dtype-like objects
-rw-r--r--numpy/__init__.pyi5
-rw-r--r--numpy/typing/__init__.py17
-rw-r--r--numpy/typing/_dtype_like.py161
-rw-r--r--numpy/typing/tests/data/fail/dtype.py12
4 files changed, 182 insertions, 13 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 911b496df..fe9dc5914 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -682,14 +682,13 @@ class dtype(Generic[_DTypeScalar_co]):
align: bool = ...,
copy: bool = ...,
) -> dtype[_DTypeScalar_co]: ...
- # TODO: handle _SupportsDType better
@overload
def __new__(
cls,
- dtype: _SupportsDType,
+ dtype: _SupportsDType[dtype[_DTypeScalar_co]],
align: bool = ...,
copy: bool = ...,
- ) -> dtype[Any]: ...
+ ) -> dtype[_DTypeScalar_co]: ...
# Handle strings that can't be expressed as literals; i.e. s1, s2, ...
@overload
def __new__(
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index 4ec1f4b2f..8147789fb 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -297,7 +297,22 @@ from ._scalars import (
_VoidLike_co,
)
from ._shape import _Shape, _ShapeLike
-from ._dtype_like import _SupportsDType, _VoidDTypeLike, DTypeLike as DTypeLike
+from ._dtype_like import (
+ DTypeLike as DTypeLike,
+ _SupportsDType,
+ _VoidDTypeLike,
+ _DTypeLikeBool,
+ _DTypeLikeUInt,
+ _DTypeLikeInt,
+ _DTypeLikeFloat,
+ _DTypeLikeComplex,
+ _DTypeLikeTD64,
+ _DTypeLikeDT64,
+ _DTypeLikeObject,
+ _DTypeLikeVoid,
+ _DTypeLikeStr,
+ _DTypeLikeBytes,
+)
from ._array_like import (
ArrayLike as ArrayLike,
_ArrayLike,
diff --git a/numpy/typing/_dtype_like.py b/numpy/typing/_dtype_like.py
index 1953bd5fc..f86b4a67c 100644
--- a/numpy/typing/_dtype_like.py
+++ b/numpy/typing/_dtype_like.py
@@ -1,7 +1,7 @@
import sys
-from typing import Any, List, Sequence, Tuple, Union, TYPE_CHECKING
+from typing import Any, List, Sequence, Tuple, Union, Type, TypeVar, TYPE_CHECKING
-from numpy import dtype
+import numpy as np
from ._shape import _ShapeLike
if sys.version_info >= (3, 8):
@@ -15,6 +15,48 @@ else:
else:
HAVE_PROTOCOL = True
+from ._char_codes import (
+ _BoolCodes,
+ _UInt8Codes,
+ _UInt16Codes,
+ _UInt32Codes,
+ _UInt64Codes,
+ _Int8Codes,
+ _Int16Codes,
+ _Int32Codes,
+ _Int64Codes,
+ _Float16Codes,
+ _Float32Codes,
+ _Float64Codes,
+ _Complex64Codes,
+ _Complex128Codes,
+ _ByteCodes,
+ _ShortCodes,
+ _IntCCodes,
+ _IntPCodes,
+ _IntCodes,
+ _LongLongCodes,
+ _UByteCodes,
+ _UShortCodes,
+ _UIntCCodes,
+ _UIntPCodes,
+ _UIntCodes,
+ _ULongLongCodes,
+ _HalfCodes,
+ _SingleCodes,
+ _DoubleCodes,
+ _LongDoubleCodes,
+ _CSingleCodes,
+ _CDoubleCodes,
+ _CLongDoubleCodes,
+ _DT64Codes,
+ _TD64Codes,
+ _StrCodes,
+ _BytesCodes,
+ _VoidCodes,
+ _ObjectCodes,
+)
+
_DTypeLikeNested = Any # TODO: wait for support for recursive types
if TYPE_CHECKING or HAVE_PROTOCOL:
@@ -30,9 +72,12 @@ if TYPE_CHECKING or HAVE_PROTOCOL:
itemsize: int
aligned: bool
+ _DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype)
+
# A protocol for anything with the dtype attribute
- class _SupportsDType(Protocol):
- dtype: _DTypeLikeNested
+ class _SupportsDType(Protocol[_DType_co]):
+ @property
+ def dtype(self) -> _DType_co: ...
else:
_DTypeDict = Any
@@ -61,13 +106,13 @@ _VoidDTypeLike = Union[
# Anything that can be coerced into numpy.dtype.
# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
DTypeLike = Union[
- dtype,
+ np.dtype,
# default data type (float64)
None,
# array-scalar types and generic types
type, # TODO: enumerate these when we add type hints for numpy scalars
# anything with a dtype attribute
- _SupportsDType,
+ "_SupportsDType[np.dtype[Any]]",
# character codes, type strings or comma-separated fields, e.g., 'float64'
str,
_VoidDTypeLike,
@@ -79,3 +124,107 @@ DTypeLike = Union[
# therefore not included in the Union defining `DTypeLike`.
#
# See https://github.com/numpy/numpy/issues/16891 for more details.
+
+# Aliases for commonly used dtype-like objects.
+# Note that the precision of `np.number` subclasses is ignored herein.
+_DTypeLikeBool = Union[
+ Type[bool],
+ Type[np.bool_],
+ "np.dtype[np.bool_]",
+ "_SupportsDType[np.dtype[np.bool_]]",
+ _BoolCodes,
+]
+_DTypeLikeUInt = Union[
+ Type[np.unsignedinteger],
+ "np.dtype[np.unsignedinteger]",
+ "_SupportsDType[np.dtype[np.unsignedinteger]]",
+ _UInt8Codes,
+ _UInt16Codes,
+ _UInt32Codes,
+ _UInt64Codes,
+ _UByteCodes,
+ _UShortCodes,
+ _UIntCCodes,
+ _UIntPCodes,
+ _UIntCodes,
+ _ULongLongCodes,
+]
+_DTypeLikeInt = Union[
+ Type[int],
+ Type[np.signedinteger],
+ "np.dtype[np.signedinteger]",
+ "_SupportsDType[np.dtype[np.signedinteger]]",
+ _Int8Codes,
+ _Int16Codes,
+ _Int32Codes,
+ _Int64Codes,
+ _ByteCodes,
+ _ShortCodes,
+ _IntCCodes,
+ _IntPCodes,
+ _IntCodes,
+ _LongLongCodes,
+]
+_DTypeLikeFloat = Union[
+ Type[float],
+ Type[np.floating],
+ "np.dtype[np.floating]",
+ "_SupportsDType[np.dtype[np.floating]]",
+ _Float16Codes,
+ _Float32Codes,
+ _Float64Codes,
+ _HalfCodes,
+ _SingleCodes,
+ _DoubleCodes,
+ _LongDoubleCodes,
+]
+_DTypeLikeComplex = Union[
+ Type[complex],
+ Type[np.complexfloating],
+ "np.dtype[np.complexfloating]",
+ "_SupportsDType[np.dtype[np.complexfloating]]",
+ _Complex64Codes,
+ _Complex128Codes,
+ _CSingleCodes,
+ _CDoubleCodes,
+ _CLongDoubleCodes,
+]
+_DTypeLikeDT64 = Union[
+ Type[np.timedelta64],
+ "np.dtype[np.timedelta64]",
+ "_SupportsDType[np.dtype[np.timedelta64]]",
+ _TD64Codes,
+]
+_DTypeLikeTD64 = Union[
+ Type[np.datetime64],
+ "np.dtype[np.datetime64]",
+ "_SupportsDType[np.dtype[np.datetime64]]",
+ _DT64Codes,
+]
+_DTypeLikeStr = Union[
+ Type[str],
+ Type[np.str_],
+ "np.dtype[np.str_]",
+ "_SupportsDType[np.dtype[np.str_]]",
+ _StrCodes,
+]
+_DTypeLikeBytes = Union[
+ Type[bytes],
+ Type[np.bytes_],
+ "np.dtype[np.bytes_]",
+ "_SupportsDType[np.dtype[np.bytes_]]",
+ _BytesCodes,
+]
+_DTypeLikeVoid = Union[
+ Type[np.void],
+ "np.dtype[np.void]",
+ "_SupportsDType[np.dtype[np.void]]",
+ _VoidCodes,
+ _VoidDTypeLike,
+]
+_DTypeLikeObject = Union[
+ type,
+ "np.dtype[np.object_]",
+ "_SupportsDType[np.dtype[np.object_]]",
+ _ObjectCodes,
+]
diff --git a/numpy/typing/tests/data/fail/dtype.py b/numpy/typing/tests/data/fail/dtype.py
index 7d4783d8f..7d419a1d1 100644
--- a/numpy/typing/tests/data/fail/dtype.py
+++ b/numpy/typing/tests/data/fail/dtype.py
@@ -1,10 +1,16 @@
import numpy as np
-class Test:
- not_dtype = float
+class Test1:
+ not_dtype = np.dtype(float)
-np.dtype(Test()) # E: No overload variant of "dtype" matches
+
+class Test2:
+ dtype = float
+
+
+np.dtype(Test1()) # E: No overload variant of "dtype" matches
+np.dtype(Test2()) # E: incompatible type
np.dtype( # E: No overload variant of "dtype" matches
{