From bc20d334b575f897157b1cf3eecda77f3e40e049 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 20:01:11 -0600 Subject: Move the array API dtype categories into the top level They are not an official part of the spec but are useful for various parts of the implementation. --- numpy/array_api/_array_object.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) (limited to 'numpy/array_api/_array_object.py') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index af70058e6..50906642d 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -98,23 +98,14 @@ class Array: if other is NotImplemented: return other """ - from ._dtypes import _result_type - - _dtypes = { - 'all': _all_dtypes, - 'numeric': _numeric_dtypes, - 'integer': _integer_dtypes, - 'integer or boolean': _integer_or_boolean_dtypes, - 'boolean': _boolean_dtypes, - 'floating-point': _floating_dtypes, - } - - if self.dtype not in _dtypes[dtype_category]: + from ._dtypes import _result_type, _dtype_categories + + if self.dtype not in _dtype_categories[dtype_category]: raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) elif isinstance(other, Array): - if other.dtype not in _dtypes[dtype_category]: + if other.dtype not in _dtype_categories[dtype_category]: raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') else: return NotImplemented -- cgit v1.2.1