From 45dbdc9d8fd3fa7fbabbddf690f6c892cc24aa1d Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Fri, 3 Sep 2021 15:10:23 +0530 Subject: CopyMode added to np.array_api --- numpy/array_api/_creation_functions.py | 2 +- numpy/array_api/tests/test_creation_functions.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index e9c01e7e6..c15d54db1 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -43,7 +43,7 @@ def asarray( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - copy: Optional[bool] = None, + copy: Optional[Union[bool | np._CopyMode]] = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray `. diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 3cb8865cd..b0c99cd45 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -56,12 +56,17 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 1) assert all(a[0] == 0) - # Once copy=False is implemented, replace this with - # a = asarray([1]) - # b = asarray(a, copy=False) - # a[0] = 0 - # assert all(b[0] == 0) + a = asarray([1]) + b = asarray(a, copy=np._CopyMode.ALWAYS) + a[0] = 0 + assert all(b[0] == 1) + assert all(a[0] == 0) + a = asarray([1]) + b = asarray(a, copy=np._CopyMode.NEVER) + a[0] = 0 + assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) + assert_raises(NotImplementedError, lambda: asarray(a, copy=np._CopyMode.NEVER)) def test_arange_errors(): -- cgit v1.2.1 From a39312cf4eca9dc9573737a01f16fee24efb6c0a Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Fri, 3 Sep 2021 15:15:18 +0530 Subject: fixed linting issues --- numpy/array_api/tests/test_creation_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index b0c99cd45..9d2909e91 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -66,7 +66,8 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) - assert_raises(NotImplementedError, lambda: asarray(a, copy=np._CopyMode.NEVER)) + assert_raises(NotImplementedError, lambda: asarray(a, + copy=np._CopyMode.NEVER)) def test_arange_errors(): -- cgit v1.2.1 From c2acd5b25a04783fbbe3ba32426039e4dbe9207e Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Fri, 3 Sep 2021 15:17:58 +0530 Subject: fixed linting issues --- numpy/array_api/tests/test_creation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 9d2909e91..7182209dc 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -66,8 +66,8 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) - assert_raises(NotImplementedError, lambda: asarray(a, - copy=np._CopyMode.NEVER)) + assert_raises(NotImplementedError, + lambda: asarray(a, copy=np._CopyMode.NEVER)) def test_arange_errors(): -- cgit v1.2.1 From 56647dd47345a7fd24b4ee8d9d52025fcdc3b9ae Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Sat, 4 Sep 2021 22:33:52 +0530 Subject: Addressed reviews --- numpy/array_api/_creation_functions.py | 6 +++--- numpy/array_api/tests/test_creation_functions.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index c15d54db1..2d6cf4414 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -43,7 +43,7 @@ def asarray( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - copy: Optional[Union[bool | np._CopyMode]] = None, + copy: Optional[Union[bool, np._CopyMode]] = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray `. @@ -57,11 +57,11 @@ def asarray( _check_valid_dtype(dtype) if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") - if copy is False: + if copy in (False, np._CopyMode.IF_NEEDED): # Note: copy=False is not yet implemented in np.asarray raise NotImplementedError("copy=False is not yet implemented") if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): - if copy is True: + if copy in (True, np._CopyMode.ALWAYS): return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 7182209dc..2ee23a47b 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -67,7 +67,7 @@ def test_asarray_copy(): assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) assert_raises(NotImplementedError, - lambda: asarray(a, copy=np._CopyMode.NEVER)) + lambda: asarray(a, copy=np._CopyMode.IF_NEEDED)) def test_arange_errors(): -- cgit v1.2.1