From e4a11cb29af12b0e67fda3f924585b23cee0b567 Mon Sep 17 00:00:00 2001 From: Sista Seetaram Date: Fri, 3 Sep 2021 00:25:57 +0530 Subject: fixed unhashable instance and potential exception as listed in LGTM#19077 --- numpy/array_api/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index d530a91ae..8b7832d74 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -37,7 +37,7 @@ NestedSequence = Sequence[Sequence[Any]] Device = Any Dtype = Type[ - Union[[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]] + Union[(int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64)] ] SupportsDLPack = Any SupportsBufferProtocol = Any -- cgit v1.2.1 From e4c589054b342dc750cdd09686c0c01ab762679c Mon Sep 17 00:00:00 2001 From: Sista Seetaram Date: Fri, 3 Sep 2021 06:09:38 +0530 Subject: fix unhashable instance and potential exception identified by LGTM --- numpy/array_api/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 8b7832d74..831a108bc 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -37,7 +37,7 @@ NestedSequence = Sequence[Sequence[Any]] Device = Any Dtype = Type[ - Union[(int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64)] + Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64] ] SupportsDLPack = Any SupportsBufferProtocol = Any -- cgit v1.2.1 From 2d112a98ed7597c4120b31908384ae09b0304659 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Sat, 25 Sep 2021 17:34:22 -0500 Subject: ENH: Updates to numpy.array_api (#19937) * Add __index__ to array_api and update __int__, __bool__, and __float__ The spec specifies that they should only work on arrays with corresponding dtypes. __index__ is new in the spec since the initial PR, and works identically to np.array.__index__. * Add the to_device method to the array_api This method is new since #18585. It does nothing in NumPy since NumPy does not support non-CPU devices. * Update transpose methods in the array_api transpose() was renamed to matrix_transpose() and now operates on stacks of matrices. A function to permute dimensions will be added once it is finalized in the spec. The attribute mT was added and the T attribute was updated to only operate on 2-dimensional arrays as per the spec. * Restrict input dtypes in the array API statistical functions * Add the dtype parameter to the array API sum() and prod() * Add the function permute_dims() to the array_api namespace permute_dims() is the replacement for transpose(), which was split into permute_dims() and matrix_transpose(). * Add tril and triu to the array API namespace * Fix the array_api Array.__repr__ to indent the array properly * Make the Device type in the array_api just accept the string "cpu" --- numpy/array_api/__init__.py | 11 ++++++--- numpy/array_api/_array_object.py | 34 ++++++++++++++++++++++++++- numpy/array_api/_creation_functions.py | 30 +++++++++++++++++++++++- numpy/array_api/_linear_algebra_functions.py | 13 +++++------ numpy/array_api/_manipulation_functions.py | 11 +++++++++ numpy/array_api/_statistical_functions.py | 35 +++++++++++++++++++++++++++- numpy/array_api/_typing.py | 4 ++-- numpy/array_api/tests/test_array_object.py | 30 +++++++++++++++++++----- 8 files changed, 147 insertions(+), 21 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 790157504..d8b29057e 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -143,6 +143,8 @@ from ._creation_functions import ( meshgrid, ones, ones_like, + tril, + triu, zeros, zeros_like, ) @@ -160,6 +162,8 @@ __all__ += [ "meshgrid", "ones", "ones_like", + "tril", + "triu", "zeros", "zeros_like", ] @@ -333,21 +337,22 @@ __all__ += [ # from ._linear_algebra_functions import einsum # __all__ += ['einsum'] -from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot +from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot -__all__ += ["matmul", "tensordot", "transpose", "vecdot"] +__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] from ._manipulation_functions import ( concat, expand_dims, flip, + permute_dims, reshape, roll, squeeze, stack, ) -__all__ += ["concat", "expand_dims", "flip", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"] from ._searching_functions import argmax, argmin, nonzero, where diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 2d746e78b..830319e8c 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -99,7 +99,10 @@ class Array: """ Performs the operation __repr__. """ - return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})" + prefix = "Array(" + suffix = f", dtype={self.dtype.name})" + mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) + return prefix + mid + suffix # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than @@ -391,6 +394,8 @@ class Array: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("bool is only allowed on arrays with 0 dimensions") + if self.dtype not in _boolean_dtypes: + raise ValueError("bool is only allowed on boolean arrays") res = self._array.__bool__() return res @@ -429,6 +434,8 @@ class Array: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("float is only allowed on arrays with 0 dimensions") + if self.dtype not in _floating_dtypes: + raise ValueError("float is only allowed on floating-point arrays") res = self._array.__float__() return res @@ -488,9 +495,18 @@ class Array: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("int is only allowed on arrays with 0 dimensions") + if self.dtype not in _integer_dtypes: + raise ValueError("int is only allowed on integer arrays") res = self._array.__int__() return res + def __index__(self: Array, /) -> int: + """ + Performs the operation __index__. + """ + res = self._array.__index__() + return res + def __invert__(self: Array, /) -> Array: """ Performs the operation __invert__. @@ -979,6 +995,11 @@ class Array: res = self._array.__rxor__(other._array) return self.__class__._new(res) + def to_device(self: Array, device: Device, /) -> Array: + if device == 'cpu': + return self + raise ValueError(f"Unsupported device {device!r}") + @property def dtype(self) -> Dtype: """ @@ -992,6 +1013,12 @@ class Array: def device(self) -> Device: return "cpu" + # Note: mT is new in array API spec (see matrix_transpose) + @property + def mT(self) -> Array: + from ._linear_algebra_functions import matrix_transpose + return matrix_transpose(self) + @property def ndim(self) -> int: """ @@ -1026,4 +1053,9 @@ class Array: See its docstring for more information. """ + # Note: T only works on 2-dimensional arrays. See the corresponding + # note in the specification: + # https://data-apis.org/array-api/latest/API_specification/array_object.html#t + if self.ndim != 2: + raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.") return self._array.T diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index e9c01e7e6..9f8136267 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -22,7 +22,7 @@ def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. # We use this instead of "dtype in _all_dtypes" because the dtype objects - # define equality with the sorts of things we want to disallw. + # define equality with the sorts of things we want to disallow. for d in (None,) + _all_dtypes: if dtype is d: return @@ -281,6 +281,34 @@ def ones_like( return Array._new(np.ones_like(x._array, dtype=dtype)) +def tril(x: Array, /, *, k: int = 0) -> Array: + """ + Array API compatible wrapper for :py:func:`np.tril `. + + See its docstring for more information. + """ + from ._array_object import Array + + if x.ndim < 2: + # Note: Unlike np.tril, x must be at least 2-D + raise ValueError("x must be at least 2-dimensional for tril") + return Array._new(np.tril(x._array, k=k)) + + +def triu(x: Array, /, *, k: int = 0) -> Array: + """ + Array API compatible wrapper for :py:func:`np.triu `. + + See its docstring for more information. + """ + from ._array_object import Array + + if x.ndim < 2: + # Note: Unlike np.triu, x must be at least 2-D + raise ValueError("x must be at least 2-dimensional for triu") + return Array._new(np.triu(x._array, k=k)) + + def zeros( shape: Union[int, Tuple[int, ...]], *, diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py index 089081725..7a6c9846c 100644 --- a/numpy/array_api/_linear_algebra_functions.py +++ b/numpy/array_api/_linear_algebra_functions.py @@ -52,13 +52,12 @@ def tensordot( return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) -def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: - """ - Array API compatible wrapper for :py:func:`np.transpose `. - - See its docstring for more information. - """ - return Array._new(np.transpose(x._array, axes=axes)) +# Note: this function is new in the array API spec. Unlike transpose, it only +# transposes the last two axes. +def matrix_transpose(x: Array, /) -> Array: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return Array._new(np.swapaxes(x._array, -1, -2)) # Note: vecdot is not in NumPy diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index c11866261..4f2114ff5 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -41,6 +41,17 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> return Array._new(np.flip(x._array, axis=axis)) +# Note: The function name is different here (see also matrix_transpose). +# Unlike transpose(), the axes argument is required. +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: + """ + Array API compatible wrapper for :py:func:`np.transpose `. + + See its docstring for more information. + """ + return Array._new(np.transpose(x._array, axes)) + + def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape `. diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index 63790b447..c5abf9468 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -1,8 +1,17 @@ from __future__ import annotations +from ._dtypes import ( + _floating_dtypes, + _numeric_dtypes, +) from ._array_object import Array +from ._creation_functions import asarray +from ._dtypes import float32, float64 -from typing import Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union + +if TYPE_CHECKING: + from ._typing import Dtype import numpy as np @@ -14,6 +23,8 @@ def max( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in max") return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) @@ -24,6 +35,8 @@ def mean( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in mean") return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) @@ -34,6 +47,8 @@ def min( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in min") return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) @@ -42,8 +57,15 @@ def prod( /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in prod") + # Note: sum() and prod() always upcast float32 to float64 for dtype=None + # We need to do so here before computing the product to avoid overflow + if dtype is None and x.dtype == float32: + x = asarray(x, dtype=float64) return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims)) @@ -56,6 +78,8 @@ def std( keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in std") return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) @@ -64,8 +88,15 @@ def sum( /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, keepdims: bool = False, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in sum") + # Note: sum() and prod() always upcast float32 to float64 for dtype=None + # We need to do so here before summing to avoid overflow + if dtype is None and x.dtype == float32: + x = asarray(x, dtype=float64) return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims)) @@ -78,4 +109,6 @@ def var( keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in var") return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 831a108bc..5f937a56c 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -15,7 +15,7 @@ __all__ = [ "PyCapsule", ] -from typing import Any, Sequence, Type, Union +from typing import Any, Literal, Sequence, Type, Union from . import ( Array, @@ -35,7 +35,7 @@ from . import ( # similar comment in numpy/typing/_array_like.py NestedSequence = Sequence[Sequence[Any]] -Device = Any +Device = Literal["cpu"] Dtype = Type[ Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64] ] diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 088e09b9f..7959f92b4 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -1,3 +1,5 @@ +import operator + from numpy.testing import assert_raises import numpy as np @@ -255,15 +257,31 @@ def test_operators(): def test_python_scalar_construtors(): - a = asarray(False) - b = asarray(0) - c = asarray(0.0) + b = asarray(False) + i = asarray(0) + f = asarray(0.0) - assert bool(a) == bool(b) == bool(c) == False - assert int(a) == int(b) == int(c) == 0 - assert float(a) == float(b) == float(c) == 0.0 + assert bool(b) == False + assert int(i) == 0 + assert float(f) == 0.0 + assert operator.index(i) == 0 # bool/int/float should only be allowed on 0-D arrays. assert_raises(TypeError, lambda: bool(asarray([False]))) assert_raises(TypeError, lambda: int(asarray([0]))) assert_raises(TypeError, lambda: float(asarray([0.0]))) + assert_raises(TypeError, lambda: operator.index(asarray([0]))) + + # bool/int/float should only be allowed on arrays of the corresponding + # dtype + assert_raises(ValueError, lambda: bool(i)) + assert_raises(ValueError, lambda: bool(f)) + + assert_raises(ValueError, lambda: int(b)) + assert_raises(ValueError, lambda: int(f)) + + assert_raises(ValueError, lambda: float(b)) + assert_raises(ValueError, lambda: float(i)) + + assert_raises(TypeError, lambda: operator.index(b)) + assert_raises(TypeError, lambda: operator.index(f)) -- cgit v1.2.1 From a5beccfa3574f4fcb1b6030737b728e65803791f Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Mon, 27 Sep 2021 14:43:16 +0200 Subject: MAINT: Fix invalid parameter types used in `Dtype` --- numpy/array_api/_typing.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 5f937a56c..f66279fbc 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -15,10 +15,12 @@ __all__ = [ "PyCapsule", ] -from typing import Any, Literal, Sequence, Type, Union +import sys +from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING -from . import ( - Array, +from . import Array +from numpy import ( + dtype, int8, int16, int32, @@ -36,9 +38,22 @@ from . import ( NestedSequence = Sequence[Sequence[Any]] Device = Literal["cpu"] -Dtype = Type[ - Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64] -] +if TYPE_CHECKING or sys.version_info >= (3, 9): + Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + ]] +else: + Dtype = dtype + SupportsDLPack = Any SupportsBufferProtocol = Any PyCapsule = Any -- cgit v1.2.1 From 4f7e991960c24fc9548f8f3d6d5f8967c2ece84a Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Mon, 27 Sep 2021 14:44:02 +0200 Subject: MAINT: Add a missing subscription slot to `NestedSequence` --- numpy/array_api/_typing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index f66279fbc..4785f5fe3 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -16,7 +16,7 @@ __all__ = [ ] import sys -from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING +from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar from . import Array from numpy import ( @@ -35,7 +35,8 @@ from numpy import ( # This should really be recursive, but that isn't supported yet. See the # similar comment in numpy/typing/_array_like.py -NestedSequence = Sequence[Sequence[Any]] +_T = TypeVar("_T") +NestedSequence = Sequence[Sequence[_T]] Device = Literal["cpu"] if TYPE_CHECKING or sys.version_info >= (3, 9): -- cgit v1.2.1 From dc69553ef179b6713d85ec747e6d030dd7087f05 Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Mon, 27 Sep 2021 14:44:46 +0200 Subject: MAINT: Remove the `Sequence` encapsulation from variadic arguments --- numpy/array_api/_creation_functions.py | 2 +- numpy/array_api/_data_type_functions.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index 9f8136267..8a23cd88e 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -232,7 +232,7 @@ def linspace( return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) -def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]: +def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: """ Array API compatible wrapper for :py:func:`np.meshgrid `. diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index fd92aa250..7ccbe9469 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: import numpy as np -def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> List[Array]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays `. @@ -98,7 +98,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: return iinfo_object(ii.bits, ii.max, ii.min) -def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: +def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: """ Array API compatible wrapper for :py:func:`np.result_type `. -- cgit v1.2.1 From 4f8f50d5b5e992afaf0ef08773bd88d696683bd3 Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Mon, 27 Sep 2021 14:47:22 +0200 Subject: MAINT: Import `Array` from the `_array_object` namespace Changed as `Array` does not live in the main `np.array_api` namespace --- numpy/array_api/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 4785f5fe3..519e8463c 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -18,7 +18,7 @@ __all__ = [ import sys from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar -from . import Array +from ._array_object import Array from numpy import ( dtype, int8, -- cgit v1.2.1 From 9c356d55d78620532ed92987af1721d7ca4034ec Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Wed, 29 Sep 2021 22:52:14 +0200 Subject: Disallow `k=None` for the `eye` function --- numpy/array_api/_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index 8a23cd88e..e36807468 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -134,7 +134,7 @@ def eye( n_cols: Optional[int] = None, /, *, - k: Optional[int] = 0, + k: int = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> Array: -- cgit v1.2.1 From 4d23ebeb068c8d6ba6edfc11d32ab2af8bb89c74 Mon Sep 17 00:00:00 2001 From: Alessia Marcolini <98marcolini@gmail.com> Date: Fri, 8 Oct 2021 09:49:11 +0000 Subject: MAINT: remove unused imports --- numpy/array_api/tests/test_creation_functions.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 3cb8865cd..7b633eaf1 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -8,30 +8,15 @@ from .._creation_functions import ( empty, empty_like, eye, - from_dlpack, full, full_like, linspace, - meshgrid, ones, ones_like, zeros, zeros_like, ) from .._array_object import Array -from .._dtypes import ( - _all_dtypes, - _boolean_dtypes, - _floating_dtypes, - _integer_dtypes, - _integer_or_boolean_dtypes, - _numeric_dtypes, - int8, - int16, - int32, - int64, - uint64, -) def test_asarray_errors(): -- cgit v1.2.1 From f931a434839222bb00282a432d6d6a0c2c52eb7d Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:10:47 +0200 Subject: ENH: Replace `NestedSequence` with a proper nested sequence protocol --- numpy/array_api/_typing.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 519e8463c..5e980b16f 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only valid for inputs that match the given type annotations. """ +from __future__ import annotations + __all__ = [ "Array", "Device", @@ -16,7 +18,16 @@ __all__ = [ ] import sys -from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar +from typing import ( + Any, + Literal, + Sequence, + Type, + Union, + TYPE_CHECKING, + TypeVar, + Protocol, +) from ._array_object import Array from numpy import ( @@ -33,10 +44,11 @@ from numpy import ( float64, ) -# This should really be recursive, but that isn't supported yet. See the -# similar comment in numpy/typing/_array_like.py -_T = TypeVar("_T") -NestedSequence = Sequence[Sequence[_T]] +_T_co = TypeVar("_T_co", covariant=True) + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... + def __len__(self, /) -> int: ... Device = Literal["cpu"] if TYPE_CHECKING or sys.version_info >= (3, 9): -- cgit v1.2.1 From 3952e8f1390629078fdb229236b3b1ce40140c32 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:11:06 +0200 Subject: ENH: Change `SupportsDLPack` into a protocol --- numpy/array_api/_typing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 5e980b16f..dfa87b358 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -67,6 +67,8 @@ if TYPE_CHECKING or sys.version_info >= (3, 9): else: Dtype = dtype -SupportsDLPack = Any SupportsBufferProtocol = Any PyCapsule = Any + +class SupportsDLPack(Protocol): + def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... -- cgit v1.2.1 From d74bea12d19dd92c9cf07cac35e94d45fb331832 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Mon, 18 Oct 2021 18:13:39 +0200 Subject: MAINT: Replace the `__array_namespace__` return type with `Any` Replace `object` as it cannot be used for expressing the objects in the array namespace. --- numpy/array_api/_array_object.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 830319e8c..ef66c5efd 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -29,7 +29,7 @@ from ._dtypes import ( _dtype_categories, ) -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union, Any if TYPE_CHECKING: from ._typing import PyCapsule, Device, Dtype @@ -382,7 +382,7 @@ class Array: def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None - ) -> object: + ) -> Any: if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") return array_api -- cgit v1.2.1