summaryrefslogtreecommitdiff
path: root/numpy/testing
diff options
context:
space:
mode:
authorJon Morris <jontwo@users.noreply.github.com>2022-06-24 16:51:36 +0100
committerGitHub <noreply@github.com>2022-06-24 11:51:36 -0400
commitcafec60a5e28af98fb8798049edd7942720d2d74 (patch)
tree00d6ea28360e1a59972d07de162bd97e782b2b24 /numpy/testing
parent019c8c9b2a7c084eb01cf4d8569799a5537d884d (diff)
downloadnumpy-cafec60a5e28af98fb8798049edd7942720d2d74.tar.gz
ENH: Add strict parameter to assert_array_equal. (#21595)
Fixes #9542 Co-authored-by: Bas van Beek <43369155+BvB93@users.noreply.github.com>
Diffstat (limited to 'numpy/testing')
-rw-r--r--numpy/testing/_private/utils.py52
-rw-r--r--numpy/testing/_private/utils.pyi4
-rw-r--r--numpy/testing/tests/test_utils.py37
3 files changed, 87 insertions, 6 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index e4f8b9892..f60ca6922 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -699,7 +699,8 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
- precision=6, equal_nan=True, equal_inf=True):
+ precision=6, equal_nan=True, equal_inf=True,
+ *, strict=False):
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_
@@ -753,11 +754,18 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
return y_id
try:
- cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
+ if strict:
+ cond = x.shape == y.shape and x.dtype == y.dtype
+ else:
+ cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
if not cond:
+ if x.shape != y.shape:
+ reason = f'\n(shapes {x.shape}, {y.shape} mismatch)'
+ else:
+ reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)'
msg = build_err_msg([x, y],
err_msg
- + f'\n(shapes {x.shape}, {y.shape} mismatch)',
+ + reason,
verbose=verbose, header=header,
names=('x', 'y'), precision=precision)
raise AssertionError(msg)
@@ -852,7 +860,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
raise ValueError(msg)
-def assert_array_equal(x, y, err_msg='', verbose=True):
+def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False):
"""
Raises an AssertionError if two array_like objects are not equal.
@@ -876,6 +884,10 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
+ strict : bool, optional
+ If True, raise an AssertionError when either the shape or the data
+ type of the array_like objects does not match. The special
+ handling for scalars mentioned in the Notes section is disabled.
Raises
------
@@ -892,7 +904,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
-----
When one of `x` and `y` is a scalar and the other is array_like, the
function checks that each element of the array_like object is equal to
- the scalar.
+ the scalar. This behaviour can be disabled with the `strict` parameter.
Examples
--------
@@ -929,10 +941,38 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
>>> x = np.full((2, 5), fill_value=3)
>>> np.testing.assert_array_equal(x, 3)
+ Use `strict` to raise an AssertionError when comparing a scalar with an
+ array:
+
+ >>> np.testing.assert_array_equal(x, 3, strict=True)
+ Traceback (most recent call last):
+ ...
+ AssertionError:
+ Arrays are not equal
+ <BLANKLINE>
+ (shapes (2, 5), () mismatch)
+ x: array([[3, 3, 3, 3, 3],
+ [3, 3, 3, 3, 3]])
+ y: array(3)
+
+ The `strict` parameter also ensures that the array data types match:
+
+ >>> x = np.array([2, 2, 2])
+ >>> y = np.array([2., 2., 2.], dtype=np.float32)
+ >>> np.testing.assert_array_equal(x, y, strict=True)
+ Traceback (most recent call last):
+ ...
+ AssertionError:
+ Arrays are not equal
+ <BLANKLINE>
+ (dtypes int64, float32 mismatch)
+ x: array([2, 2, 2])
+ y: array([2., 2., 2.], dtype=float32)
"""
__tracebackhide__ = True # Hide traceback for py.test
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
- verbose=verbose, header='Arrays are not equal')
+ verbose=verbose, header='Arrays are not equal',
+ strict=strict)
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi
index 0be13b729..6e051e914 100644
--- a/numpy/testing/_private/utils.pyi
+++ b/numpy/testing/_private/utils.pyi
@@ -200,6 +200,8 @@ def assert_array_compare(
precision: SupportsIndex = ...,
equal_nan: bool = ...,
equal_inf: bool = ...,
+ *,
+ strict: bool = ...
) -> None: ...
def assert_array_equal(
@@ -207,6 +209,8 @@ def assert_array_equal(
y: ArrayLike,
err_msg: str = ...,
verbose: bool = ...,
+ *,
+ strict: bool = ...
) -> None: ...
def assert_array_almost_equal(
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 49eeecc8e..c82343f0c 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -214,6 +214,43 @@ class TestArrayEqual(_GenericTest):
np.array([1, 2, 3], np.float32),
np.array([1, 1e-40, 3], np.float32))
+ def test_array_vs_scalar_is_equal(self):
+ """Test comparing an array with a scalar when all values are equal."""
+ a = np.array([1., 1., 1.])
+ b = 1.
+
+ self._test_equal(a, b)
+
+ def test_array_vs_scalar_not_equal(self):
+ """Test comparing an array with a scalar when not all values equal."""
+ a = np.array([1., 2., 3.])
+ b = 1.
+
+ self._test_not_equal(a, b)
+
+ def test_array_vs_scalar_strict(self):
+ """Test comparing an array with a scalar with strict option."""
+ a = np.array([1., 1., 1.])
+ b = 1.
+
+ with pytest.raises(AssertionError):
+ assert_array_equal(a, b, strict=True)
+
+ def test_array_vs_array_strict(self):
+ """Test comparing two arrays with strict option."""
+ a = np.array([1., 1., 1.])
+ b = np.array([1., 1., 1.])
+
+ assert_array_equal(a, b, strict=True)
+
+ def test_array_vs_float_array_strict(self):
+ """Test comparing two arrays with strict option."""
+ a = np.array([1, 1, 1])
+ b = np.array([1., 1., 1.])
+
+ with pytest.raises(AssertionError):
+ assert_array_equal(a, b, strict=True)
+
class TestBuildErrorMessage: