diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 18 | ||||
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 42 | ||||
-rw-r--r-- | numpy/lib/utils.py | 46 |
3 files changed, 77 insertions, 29 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 352512513..172e9a322 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -3982,23 +3982,7 @@ def _median(a, axis=None, out=None, overwrite_input=False): if np.issubdtype(a.dtype, np.inexact) and sz > 0: # warn and return nans like mean would rout = mean(part[indexer], axis=axis, out=out) - part = np.rollaxis(part, axis, part.ndim) - n = np.isnan(part[..., -1]) - if rout.ndim == 0: - if n == True: - warnings.warn("Invalid value encountered in median", - RuntimeWarning, stacklevel=3) - if out is not None: - out[...] = a.dtype.type(np.nan) - rout = out - else: - rout = a.dtype.type(np.nan) - elif np.count_nonzero(n.ravel()) > 0: - warnings.warn("Invalid value encountered in median for" + - " %d results" % np.count_nonzero(n.ravel()), - RuntimeWarning, stacklevel=3) - rout[n] = np.nan - return rout + return np.lib.utils._median_nancheck(part, rout, axis, out) else: # if there are no nans # Use mean in odd and even case to coerce data type diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index 06c0953b5..18fcb2887 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -693,18 +693,36 @@ class TestNanFunctions_Median(TestCase): def test_float_special(self): with suppress_warnings() as sup: sup.filter(RuntimeWarning) - a = np.array([[np.inf, np.nan], [np.nan, np.nan]]) - assert_equal(np.nanmedian(a, axis=0), [np.inf, np.nan]) - assert_equal(np.nanmedian(a, axis=1), [np.inf, np.nan]) - assert_equal(np.nanmedian(a), np.inf) - - # minimum fill value check - a = np.array([[np.nan, np.nan, np.inf], [np.nan, np.nan, np.inf]]) - assert_equal(np.nanmedian(a, axis=1), np.inf) - - # no mask path - a = np.array([[np.inf, np.inf], [np.inf, np.inf]]) - assert_equal(np.nanmedian(a, axis=1), np.inf) + for inf in [np.inf, -np.inf]: + a = np.array([[inf, np.nan], [np.nan, np.nan]]) + assert_equal(np.nanmedian(a, axis=0), [inf, np.nan]) + assert_equal(np.nanmedian(a, axis=1), [inf, np.nan]) + assert_equal(np.nanmedian(a), inf) + + # minimum fill value check + a = np.array([[np.nan, np.nan, inf], + [np.nan, np.nan, inf]]) + assert_equal(np.nanmedian(a), inf) + assert_equal(np.nanmedian(a, axis=0), [np.nan, np.nan, inf]) + assert_equal(np.nanmedian(a, axis=1), inf) + + # no mask path + a = np.array([[inf, inf], [inf, inf]]) + assert_equal(np.nanmedian(a, axis=1), inf) + + for i in range(0, 10): + for j in range(1, 10): + a = np.array([([np.nan] * i) + ([inf] * j)] * 2) + assert_equal(np.nanmedian(a), inf) + assert_equal(np.nanmedian(a, axis=1), inf) + assert_equal(np.nanmedian(a, axis=0), + ([np.nan] * i) + [inf] * j) + + a = np.array([([np.nan] * i) + ([-inf] * j)] * 2) + assert_equal(np.nanmedian(a), -inf) + assert_equal(np.nanmedian(a, axis=1), -inf) + assert_equal(np.nanmedian(a, axis=0), + ([np.nan] * i) + [-inf] * j) class TestNanFunctions_Percentile(TestCase): diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 5c364268c..61aa5e33b 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -8,6 +8,7 @@ import warnings from numpy.core.numerictypes import issubclass_, issubsctype, issubdtype from numpy.core import ndarray, ufunc, asarray +import numpy as np # getargspec and formatargspec were removed in Python 3.6 from numpy.compat import getargspec, formatargspec @@ -1113,4 +1114,49 @@ def safe_eval(source): import ast return ast.literal_eval(source) + + +def _median_nancheck(data, result, axis, out): + """ + Utility function to check median result from data for NaN values at the end + and return NaN in that case. Input result can also be a MaskedArray. + + Parameters + ---------- + data : array + Input data to median function + result : Array or MaskedArray + Result of median function + axis : {int, sequence of int, None}, optional + Axis or axes along which the median was computed. + out : ndarray, optional + Output array in which to place the result. + Returns + ------- + median : scalar or ndarray + Median or NaN in axes which contained NaN in the input. + """ + if data.size == 0: + return result + data = np.rollaxis(data, axis, data.ndim) + n = np.isnan(data[..., -1]) + # masked NaN values are ok + if np.ma.isMaskedArray(n): + n = n.filled(False) + if result.ndim == 0: + if n == True: + warnings.warn("Invalid value encountered in median", + RuntimeWarning, stacklevel=3) + if out is not None: + out[...] = data.dtype.type(np.nan) + result = out + else: + result = data.dtype.type(np.nan) + elif np.count_nonzero(n.ravel()) > 0: + warnings.warn("Invalid value encountered in median for" + + " %d results" % np.count_nonzero(n.ravel()), + RuntimeWarning, stacklevel=3) + result[n] = np.nan + return result + #----------------------------------------------------------------------------- |