diff options
author | Ralf Gommers <ralf.gommers@gmail.com> | 2021-01-29 12:07:01 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-29 12:07:01 +0000 |
commit | 9772bea95f3d079c2b59107ddd9731602b5edde0 (patch) | |
tree | 4738e7abad35b274431445bad3cf9a9a9181737c | |
parent | b969f8c9a2178bec235d6c47db3e7cd68c5d39c1 (diff) | |
parent | d788788f8233d96e647c64bd4070305146b9ff92 (diff) | |
download | numpy-9772bea95f3d079c2b59107ddd9731602b5edde0.tar.gz |
Merge pull request #18236 from BvB93/dtype-like
ENH: Add aliases for commonly used dtype-like objects
-rw-r--r-- | numpy/__init__.pyi | 5 | ||||
-rw-r--r-- | numpy/typing/__init__.py | 17 | ||||
-rw-r--r-- | numpy/typing/_dtype_like.py | 161 | ||||
-rw-r--r-- | numpy/typing/tests/data/fail/dtype.py | 12 |
4 files changed, 182 insertions, 13 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 911b496df..fe9dc5914 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -682,14 +682,13 @@ class dtype(Generic[_DTypeScalar_co]): align: bool = ..., copy: bool = ..., ) -> dtype[_DTypeScalar_co]: ... - # TODO: handle _SupportsDType better @overload def __new__( cls, - dtype: _SupportsDType, + dtype: _SupportsDType[dtype[_DTypeScalar_co]], align: bool = ..., copy: bool = ..., - ) -> dtype[Any]: ... + ) -> dtype[_DTypeScalar_co]: ... # Handle strings that can't be expressed as literals; i.e. s1, s2, ... @overload def __new__( diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py index 4ec1f4b2f..8147789fb 100644 --- a/numpy/typing/__init__.py +++ b/numpy/typing/__init__.py @@ -297,7 +297,22 @@ from ._scalars import ( _VoidLike_co, ) from ._shape import _Shape, _ShapeLike -from ._dtype_like import _SupportsDType, _VoidDTypeLike, DTypeLike as DTypeLike +from ._dtype_like import ( + DTypeLike as DTypeLike, + _SupportsDType, + _VoidDTypeLike, + _DTypeLikeBool, + _DTypeLikeUInt, + _DTypeLikeInt, + _DTypeLikeFloat, + _DTypeLikeComplex, + _DTypeLikeTD64, + _DTypeLikeDT64, + _DTypeLikeObject, + _DTypeLikeVoid, + _DTypeLikeStr, + _DTypeLikeBytes, +) from ._array_like import ( ArrayLike as ArrayLike, _ArrayLike, diff --git a/numpy/typing/_dtype_like.py b/numpy/typing/_dtype_like.py index 1953bd5fc..f86b4a67c 100644 --- a/numpy/typing/_dtype_like.py +++ b/numpy/typing/_dtype_like.py @@ -1,7 +1,7 @@ import sys -from typing import Any, List, Sequence, Tuple, Union, TYPE_CHECKING +from typing import Any, List, Sequence, Tuple, Union, Type, TypeVar, TYPE_CHECKING -from numpy import dtype +import numpy as np from ._shape import _ShapeLike if sys.version_info >= (3, 8): @@ -15,6 +15,48 @@ else: else: HAVE_PROTOCOL = True +from ._char_codes import ( + _BoolCodes, + _UInt8Codes, + _UInt16Codes, + _UInt32Codes, + _UInt64Codes, + _Int8Codes, + _Int16Codes, + _Int32Codes, + _Int64Codes, + _Float16Codes, + _Float32Codes, + _Float64Codes, + _Complex64Codes, + _Complex128Codes, + _ByteCodes, + _ShortCodes, + _IntCCodes, + _IntPCodes, + _IntCodes, + _LongLongCodes, + _UByteCodes, + _UShortCodes, + _UIntCCodes, + _UIntPCodes, + _UIntCodes, + _ULongLongCodes, + _HalfCodes, + _SingleCodes, + _DoubleCodes, + _LongDoubleCodes, + _CSingleCodes, + _CDoubleCodes, + _CLongDoubleCodes, + _DT64Codes, + _TD64Codes, + _StrCodes, + _BytesCodes, + _VoidCodes, + _ObjectCodes, +) + _DTypeLikeNested = Any # TODO: wait for support for recursive types if TYPE_CHECKING or HAVE_PROTOCOL: @@ -30,9 +72,12 @@ if TYPE_CHECKING or HAVE_PROTOCOL: itemsize: int aligned: bool + _DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype) + # A protocol for anything with the dtype attribute - class _SupportsDType(Protocol): - dtype: _DTypeLikeNested + class _SupportsDType(Protocol[_DType_co]): + @property + def dtype(self) -> _DType_co: ... else: _DTypeDict = Any @@ -61,13 +106,13 @@ _VoidDTypeLike = Union[ # Anything that can be coerced into numpy.dtype. # Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html DTypeLike = Union[ - dtype, + np.dtype, # default data type (float64) None, # array-scalar types and generic types type, # TODO: enumerate these when we add type hints for numpy scalars # anything with a dtype attribute - _SupportsDType, + "_SupportsDType[np.dtype[Any]]", # character codes, type strings or comma-separated fields, e.g., 'float64' str, _VoidDTypeLike, @@ -79,3 +124,107 @@ DTypeLike = Union[ # therefore not included in the Union defining `DTypeLike`. # # See https://github.com/numpy/numpy/issues/16891 for more details. + +# Aliases for commonly used dtype-like objects. +# Note that the precision of `np.number` subclasses is ignored herein. +_DTypeLikeBool = Union[ + Type[bool], + Type[np.bool_], + "np.dtype[np.bool_]", + "_SupportsDType[np.dtype[np.bool_]]", + _BoolCodes, +] +_DTypeLikeUInt = Union[ + Type[np.unsignedinteger], + "np.dtype[np.unsignedinteger]", + "_SupportsDType[np.dtype[np.unsignedinteger]]", + _UInt8Codes, + _UInt16Codes, + _UInt32Codes, + _UInt64Codes, + _UByteCodes, + _UShortCodes, + _UIntCCodes, + _UIntPCodes, + _UIntCodes, + _ULongLongCodes, +] +_DTypeLikeInt = Union[ + Type[int], + Type[np.signedinteger], + "np.dtype[np.signedinteger]", + "_SupportsDType[np.dtype[np.signedinteger]]", + _Int8Codes, + _Int16Codes, + _Int32Codes, + _Int64Codes, + _ByteCodes, + _ShortCodes, + _IntCCodes, + _IntPCodes, + _IntCodes, + _LongLongCodes, +] +_DTypeLikeFloat = Union[ + Type[float], + Type[np.floating], + "np.dtype[np.floating]", + "_SupportsDType[np.dtype[np.floating]]", + _Float16Codes, + _Float32Codes, + _Float64Codes, + _HalfCodes, + _SingleCodes, + _DoubleCodes, + _LongDoubleCodes, +] +_DTypeLikeComplex = Union[ + Type[complex], + Type[np.complexfloating], + "np.dtype[np.complexfloating]", + "_SupportsDType[np.dtype[np.complexfloating]]", + _Complex64Codes, + _Complex128Codes, + _CSingleCodes, + _CDoubleCodes, + _CLongDoubleCodes, +] +_DTypeLikeDT64 = Union[ + Type[np.timedelta64], + "np.dtype[np.timedelta64]", + "_SupportsDType[np.dtype[np.timedelta64]]", + _TD64Codes, +] +_DTypeLikeTD64 = Union[ + Type[np.datetime64], + "np.dtype[np.datetime64]", + "_SupportsDType[np.dtype[np.datetime64]]", + _DT64Codes, +] +_DTypeLikeStr = Union[ + Type[str], + Type[np.str_], + "np.dtype[np.str_]", + "_SupportsDType[np.dtype[np.str_]]", + _StrCodes, +] +_DTypeLikeBytes = Union[ + Type[bytes], + Type[np.bytes_], + "np.dtype[np.bytes_]", + "_SupportsDType[np.dtype[np.bytes_]]", + _BytesCodes, +] +_DTypeLikeVoid = Union[ + Type[np.void], + "np.dtype[np.void]", + "_SupportsDType[np.dtype[np.void]]", + _VoidCodes, + _VoidDTypeLike, +] +_DTypeLikeObject = Union[ + type, + "np.dtype[np.object_]", + "_SupportsDType[np.dtype[np.object_]]", + _ObjectCodes, +] diff --git a/numpy/typing/tests/data/fail/dtype.py b/numpy/typing/tests/data/fail/dtype.py index 7d4783d8f..7d419a1d1 100644 --- a/numpy/typing/tests/data/fail/dtype.py +++ b/numpy/typing/tests/data/fail/dtype.py @@ -1,10 +1,16 @@ import numpy as np -class Test: - not_dtype = float +class Test1: + not_dtype = np.dtype(float) -np.dtype(Test()) # E: No overload variant of "dtype" matches + +class Test2: + dtype = float + + +np.dtype(Test1()) # E: No overload variant of "dtype" matches +np.dtype(Test2()) # E: incompatible type np.dtype( # E: No overload variant of "dtype" matches { |