summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py18
-rw-r--r--numpy/lib/tests/test_nanfunctions.py42
-rw-r--r--numpy/lib/utils.py46
3 files changed, 77 insertions, 29 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 352512513..172e9a322 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3982,23 +3982,7 @@ def _median(a, axis=None, out=None, overwrite_input=False):
if np.issubdtype(a.dtype, np.inexact) and sz > 0:
# warn and return nans like mean would
rout = mean(part[indexer], axis=axis, out=out)
- part = np.rollaxis(part, axis, part.ndim)
- n = np.isnan(part[..., -1])
- if rout.ndim == 0:
- if n == True:
- warnings.warn("Invalid value encountered in median",
- RuntimeWarning, stacklevel=3)
- if out is not None:
- out[...] = a.dtype.type(np.nan)
- rout = out
- else:
- rout = a.dtype.type(np.nan)
- elif np.count_nonzero(n.ravel()) > 0:
- warnings.warn("Invalid value encountered in median for" +
- " %d results" % np.count_nonzero(n.ravel()),
- RuntimeWarning, stacklevel=3)
- rout[n] = np.nan
- return rout
+ return np.lib.utils._median_nancheck(part, rout, axis, out)
else:
# if there are no nans
# Use mean in odd and even case to coerce data type
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 06c0953b5..18fcb2887 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -693,18 +693,36 @@ class TestNanFunctions_Median(TestCase):
def test_float_special(self):
with suppress_warnings() as sup:
sup.filter(RuntimeWarning)
- a = np.array([[np.inf, np.nan], [np.nan, np.nan]])
- assert_equal(np.nanmedian(a, axis=0), [np.inf, np.nan])
- assert_equal(np.nanmedian(a, axis=1), [np.inf, np.nan])
- assert_equal(np.nanmedian(a), np.inf)
-
- # minimum fill value check
- a = np.array([[np.nan, np.nan, np.inf], [np.nan, np.nan, np.inf]])
- assert_equal(np.nanmedian(a, axis=1), np.inf)
-
- # no mask path
- a = np.array([[np.inf, np.inf], [np.inf, np.inf]])
- assert_equal(np.nanmedian(a, axis=1), np.inf)
+ for inf in [np.inf, -np.inf]:
+ a = np.array([[inf, np.nan], [np.nan, np.nan]])
+ assert_equal(np.nanmedian(a, axis=0), [inf, np.nan])
+ assert_equal(np.nanmedian(a, axis=1), [inf, np.nan])
+ assert_equal(np.nanmedian(a), inf)
+
+ # minimum fill value check
+ a = np.array([[np.nan, np.nan, inf],
+ [np.nan, np.nan, inf]])
+ assert_equal(np.nanmedian(a), inf)
+ assert_equal(np.nanmedian(a, axis=0), [np.nan, np.nan, inf])
+ assert_equal(np.nanmedian(a, axis=1), inf)
+
+ # no mask path
+ a = np.array([[inf, inf], [inf, inf]])
+ assert_equal(np.nanmedian(a, axis=1), inf)
+
+ for i in range(0, 10):
+ for j in range(1, 10):
+ a = np.array([([np.nan] * i) + ([inf] * j)] * 2)
+ assert_equal(np.nanmedian(a), inf)
+ assert_equal(np.nanmedian(a, axis=1), inf)
+ assert_equal(np.nanmedian(a, axis=0),
+ ([np.nan] * i) + [inf] * j)
+
+ a = np.array([([np.nan] * i) + ([-inf] * j)] * 2)
+ assert_equal(np.nanmedian(a), -inf)
+ assert_equal(np.nanmedian(a, axis=1), -inf)
+ assert_equal(np.nanmedian(a, axis=0),
+ ([np.nan] * i) + [-inf] * j)
class TestNanFunctions_Percentile(TestCase):
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index 5c364268c..61aa5e33b 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -8,6 +8,7 @@ import warnings
from numpy.core.numerictypes import issubclass_, issubsctype, issubdtype
from numpy.core import ndarray, ufunc, asarray
+import numpy as np
# getargspec and formatargspec were removed in Python 3.6
from numpy.compat import getargspec, formatargspec
@@ -1113,4 +1114,49 @@ def safe_eval(source):
import ast
return ast.literal_eval(source)
+
+
+def _median_nancheck(data, result, axis, out):
+ """
+ Utility function to check median result from data for NaN values at the end
+ and return NaN in that case. Input result can also be a MaskedArray.
+
+ Parameters
+ ----------
+ data : array
+ Input data to median function
+ result : Array or MaskedArray
+ Result of median function
+ axis : {int, sequence of int, None}, optional
+ Axis or axes along which the median was computed.
+ out : ndarray, optional
+ Output array in which to place the result.
+ Returns
+ -------
+ median : scalar or ndarray
+ Median or NaN in axes which contained NaN in the input.
+ """
+ if data.size == 0:
+ return result
+ data = np.rollaxis(data, axis, data.ndim)
+ n = np.isnan(data[..., -1])
+ # masked NaN values are ok
+ if np.ma.isMaskedArray(n):
+ n = n.filled(False)
+ if result.ndim == 0:
+ if n == True:
+ warnings.warn("Invalid value encountered in median",
+ RuntimeWarning, stacklevel=3)
+ if out is not None:
+ out[...] = data.dtype.type(np.nan)
+ result = out
+ else:
+ result = data.dtype.type(np.nan)
+ elif np.count_nonzero(n.ravel()) > 0:
+ warnings.warn("Invalid value encountered in median for" +
+ " %d results" % np.count_nonzero(n.ravel()),
+ RuntimeWarning, stacklevel=3)
+ result[n] = np.nan
+ return result
+
#-----------------------------------------------------------------------------