summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-09-19 23:01:50 -0700
committerEric Wieser <wieser.eric@gmail.com>2017-09-19 23:23:53 -0700
commit4031845251f6fd59da8840698bf36f4c7e974818 (patch)
treec64bd186816d723a8657a7a49e0b23c8fda14aa0 /numpy
parent68fd82271b9ea5a9e50d4e761061dfcca851382a (diff)
downloadnumpy-4031845251f6fd59da8840698bf36f4c7e974818.tar.gz
MAINT: Remove unnecessary special-casing of scalars in isclose
This means that this returns an `np.bool_` instead of a `bool`, but that seems more sensible anyway. The code was likely written this way for when scalar boolean indices didn't work
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/numeric.py16
-rw-r--r--numpy/core/tests/test_numeric.py6
2 files changed, 9 insertions, 13 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index fde08490a..c3c34666f 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -2528,13 +2528,10 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
"""
def within_tol(x, y, atol, rtol):
with errstate(invalid='ignore'):
- result = less_equal(abs(x-y), atol + rtol * abs(y))
- if isscalar(a) and isscalar(b):
- result = bool(result)
- return result
+ return less_equal(abs(x-y), atol + rtol * abs(y))
- x = array(a, copy=False, subok=True, ndmin=1)
- y = array(b, copy=False, subok=True, ndmin=1)
+ x = asanyarray(a)
+ y = asanyarray(b)
# Make sure y is an inexact type to avoid bad behavior on abs(MIN_INT).
# This will cause casting of x later. Also, make sure to allow subclasses
@@ -2561,12 +2558,11 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
if equal_nan:
# Make NaN == NaN
both_nan = isnan(x) & isnan(y)
+
+ # Needed to treat masked arrays correctly. = True would not work.
cond[both_nan] = both_nan[both_nan]
- if isscalar(a) and isscalar(b):
- return bool(cond)
- else:
- return cond
+ return cond[()] # Flatten 0d arrays to scalars
def array_equal(a1, a2):
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index d62e18b93..e8c637179 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1933,9 +1933,9 @@ class TestIsclose(object):
def test_non_finite_scalar(self):
# GH7014, when two scalars are compared the output should also be a
# scalar
- assert_(np.isclose(np.inf, -np.inf) is False)
- assert_(np.isclose(0, np.inf) is False)
- assert_(type(np.isclose(0, np.inf)) is bool)
+ assert_(np.isclose(np.inf, -np.inf) is np.False_)
+ assert_(np.isclose(0, np.inf) is np.False_)
+ assert_(type(np.isclose(0, np.inf)) is np.bool_)
class TestStdVar(object):