summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_data_type_functions.py5
-rw-r--r--numpy/array_api/tests/test_data_type_functions.py2
2 files changed, 5 insertions, 2 deletions
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index 1198ff778..1fb6062f6 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -56,10 +56,15 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
raise TypeError(f"{from_=}, but should be an array_api array or dtype")
if to not in _all_dtypes:
raise TypeError(f"{to=}, but should be a dtype")
+ # Note: We avoid np.can_cast() as it has discrepancies with the array API.
+ # See https://github.com/numpy/numpy/issues/20870
try:
+ # We promote `from_` and `to` together. We then check if the promoted
+ # dtype is `to`, which indicates if `from_` can (up)cast to `to`.
dtype = _result_type(from_, to)
return to == dtype
except TypeError:
+ # _result_type() raises if the dtypes don't promote together
return False
diff --git a/numpy/array_api/tests/test_data_type_functions.py b/numpy/array_api/tests/test_data_type_functions.py
index 3f01bb311..efe3d0abd 100644
--- a/numpy/array_api/tests/test_data_type_functions.py
+++ b/numpy/array_api/tests/test_data_type_functions.py
@@ -8,8 +8,6 @@ from numpy import array_api as xp
[
(xp.int8, xp.int16, True),
(xp.int16, xp.int8, False),
- # np.can_cast has discrepancies with the Array API
- # See https://github.com/numpy/numpy/issues/20870
(xp.bool, xp.int8, False),
(xp.asarray(0, dtype=xp.uint8), xp.int8, False),
],