From befef7b26773eddd2b656a3ab87f504e6cc173db Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 6 May 2022 09:27:27 +0000 Subject: API: Allow newaxis indexing for `array_api` arrays (#21377) * TST: Add test checking if newaxis indexing works for `array_api` Also removes previous check against newaxis indexing, which is now outdated * TST, BUG: Allow `None` in `array_api` indexing Introduces test for validating flat indexing when `None` is present * MAINT,DOC,TST: Rework of `_validate_index()` in `numpy.array_api` _validate_index() is now called as self._validate_index(shape), and does not return a key. This rework removes the recursive pattern used. Tests are introduced to cover some edge cases. Additionally, its internal docstring reflects new behaviour, and extends the flat indexing note. * MAINT: `advance` -> `advanced` (integer indexing) Co-authored-by: Aaron Meurer * BUG: array_api arrays use internal arrays from array_api array keys When an array_api array is passed as the key for get/setitem, we access the key's internal np.ndarray array to be used as the key for the internal get/setitem operation. This behaviour was initially removed when `_validate_index()` was reworked. * MAINT: Better flat indexing error message for `array_api` arrays Also better semantics for its prior ellipsis count condition Co-authored-by: Sebastian Berg * MAINT: `array_api` arrays don't special case multi-ellipsis errors This gets handled by NumPy-proper. Co-authored-by: Aaron Meurer Co-authored-by: Sebastian Berg --- numpy/array_api/tests/test_array_object.py | 63 +++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 6 deletions(-) (limited to 'numpy/array_api/tests/test_array_object.py') diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 1fe1dfddf..ba9223532 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -2,8 +2,9 @@ import operator from numpy.testing import assert_raises import numpy as np +import pytest -from .. import ones, asarray, result_type, all, equal +from .. import ones, asarray, reshape, result_type, all, equal from .._array_object import Array from .._dtypes import ( _all_dtypes, @@ -17,6 +18,7 @@ from .._dtypes import ( int32, int64, uint64, + bool as bool_, ) @@ -70,11 +72,6 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[[0, 1]]) assert_raises(IndexError, lambda: a[np.array([[0, 1]])]) - # np.newaxis is not allowed - assert_raises(IndexError, lambda: a[None]) - assert_raises(IndexError, lambda: a[None, ...]) - assert_raises(IndexError, lambda: a[..., None]) - # Multiaxis indices must contain exactly as many indices as dimensions assert_raises(IndexError, lambda: a[()]) assert_raises(IndexError, lambda: a[0,]) @@ -322,3 +319,57 @@ def test___array__(): b = np.asarray(a, dtype=np.float64) assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) assert b.dtype == np.float64 + +def test_allow_newaxis(): + a = ones(5) + indexed_a = a[None, :] + assert indexed_a.shape == (1, 5) + +def test_disallow_flat_indexing_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, 0, 0] + +def test_disallow_mask_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, asarray(True)] + +@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) +@pytest.mark.parametrize("index", ["string", False, True]) +def test_error_on_invalid_index(shape, index): + a = ones(shape) + with pytest.raises(IndexError): + a[index] + +def test_mask_0d_array_without_errors(): + a = ones(()) + a[asarray(True)] + +@pytest.mark.parametrize( + "i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])] +) +def test_error_on_invalid_index_with_ellipsis(i): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[..., i] + with pytest.raises(IndexError): + a[i, ...] + +def test_array_keys_use_private_array(): + """ + Indexing operations convert array keys before indexing the internal array + + Fails when array_api array keys are not converted into NumPy-proper arrays + in __getitem__(). This is achieved by passing array_api arrays with 0-sized + dimensions, which NumPy-proper treats erroneously - not sure why! + + TODO: Find and use appropiate __setitem__() case. + """ + a = ones((0, 0), dtype=bool_) + assert a[a].shape == (0,) + + a = ones((0,), dtype=bool_) + key = ones((0, 0), dtype=bool_) + with pytest.raises(IndexError): + a[key] -- cgit v1.2.1