summaryrefslogtreecommitdiff
path: root/numpy/testing/tests
diff options
context:
space:
mode:
authorwtli@Dirac <liwt31@163.com>2018-10-22 22:30:32 +0800
committerwtli@Dirac <liwt31@163.com>2018-10-23 16:29:35 +0800
commitbe5ea7d92d542e7c7eb055c5831a79850f4bfbee (patch)
tree2bb66021c49d699da81d89b331e7610c1cd177b2 /numpy/testing/tests
parentdb5750f6cdc2715f1c65be31f985e2cd2699d2e0 (diff)
downloadnumpy-be5ea7d92d542e7c7eb055c5831a79850f4bfbee.tar.gz
BUG: Fix misleading assert message in assert_almost_equal #12200
Fixes #12200 by making a copy of the matrix before NaN's are excluded. Add a test for it.
Diffstat (limited to 'numpy/testing/tests')
-rw-r--r--numpy/testing/tests/test_utils.py18
1 files changed, 16 insertions, 2 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index e0d3414f7..7e6b18631 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -469,7 +469,8 @@ class TestAlmostEqual(_GenericTest):
self._test_not_equal(x, z)
def test_error_message(self):
- """Check the message is formatted correctly for the decimal value"""
+ """Check the message is formatted correctly for the decimal value.
+ Also check the message when input includes inf or nan (gh12200)"""
x = np.array([1.00000000001, 2.00000000002, 3.00003])
y = np.array([1.00000000002, 2.00000000003, 3.00004])
@@ -493,6 +494,19 @@ class TestAlmostEqual(_GenericTest):
# remove anything that's not the array string
assert_equal(str(e).split('%)\n ')[1], b)
+ # Check the error message when input includes inf or nan
+ x = np.array([np.inf, 0])
+ y = np.array([np.inf, 1])
+ try:
+ self._assert_func(x, y)
+ except AssertionError as e:
+ msgs = str(e).split('\n')
+ # assert error percentage is 50%
+ assert_equal(msgs[3], '(mismatch 50.0%)')
+ # assert output array contains inf
+ assert_equal(msgs[4], ' x: array([inf, 0.])')
+ assert_equal(msgs[5], ' y: array([inf, 1.])')
+
def test_subclass_that_cannot_be_bool(self):
# While we cannot guarantee testing functions will always work for
# subclasses, the tests should ideally rely only on subclasses having
@@ -1077,7 +1091,7 @@ class TestStringEqual(object):
assert_raises(AssertionError,
lambda: assert_string_equal("foo", "hello"))
-
+
def test_regex(self):
assert_string_equal("a+*b", "a+*b")