From a233689a9837a4aeb71efe144eae578aa22ca58f Mon Sep 17 00:00:00 2001 From: Antti Kaihola Date: Wed, 12 Oct 2016 23:17:23 +0300 Subject: BUG: Make assert_allclose(..., equal_nan=False) work. As discussed in my comments for issue #8145, this patch adds the equal_nan argument to assert_array_compare(), and assert_allclose() passes the value it receives for the same argument through to assert_array_compare(). Closes #8142. --- numpy/testing/utils.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) (limited to 'numpy/testing/utils.py') diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 599e73cb0..b01de173d 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -666,7 +666,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): def assert_array_compare(comparison, x, y, err_msg='', verbose=True, - header='', precision=6): + header='', precision=6, equal_nan=True): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, isnan, isinf, any, all, inf x = array(x, copy=False, subok=True) @@ -724,21 +724,25 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, raise AssertionError(msg) if isnumber(x) and isnumber(y): - x_isnan, y_isnan = isnan(x), isnan(y) + if equal_nan: + x_isnan, y_isnan = isnan(x), isnan(y) + # Validate that NaNs are in the same place + if any(x_isnan) or any(y_isnan): + chk_same_position(x_isnan, y_isnan, hasval='nan') + x_isinf, y_isinf = isinf(x), isinf(y) - # Validate that the special values are in the same place - if any(x_isnan) or any(y_isnan): - chk_same_position(x_isnan, y_isnan, hasval='nan') + # Validate that infinite values are in the same place if any(x_isinf) or any(y_isinf): # Check +inf and -inf separately, since they are different chk_same_position(x == +inf, y == +inf, hasval='+inf') chk_same_position(x == -inf, y == -inf, hasval='-inf') # Combine all the special values - x_id, y_id = x_isnan, y_isnan - x_id |= x_isinf - y_id |= y_isinf + x_id, y_id = x_isinf, y_isinf + if equal_nan: + x_id |= x_isnan + y_id |= y_isnan # Only do the comparison if actual values are left if all(x_id): @@ -1381,7 +1385,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=False, actual, desired = np.asanyarray(actual), np.asanyarray(desired) header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol) assert_array_compare(compare, actual, desired, err_msg=str(err_msg), - verbose=verbose, header=header) + verbose=verbose, header=header, equal_nan=equal_nan) def assert_array_almost_equal_nulp(x, y, nulp=1): -- cgit v1.2.1