summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-09-05 12:58:06 -0600
committerCharles Harris <charlesr.harris@gmail.com>2016-09-06 07:43:41 -0600
commit43899e19e9a34fbdee16091cf7b46d7bf4c1d486 (patch)
tree6bb3677eb8ffc226de3b77d7e9d1d62825bf43de
parent346efba294d97cca63be3f9c3021ecf7df5ba92e (diff)
downloadnumpy-43899e19e9a34fbdee16091cf7b46d7bf4c1d486.tar.gz
TST: Add ma.median tests for valid axis.
-rw-r--r--numpy/ma/tests/test_extras.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 09836fc46..56d3dfd41 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -10,6 +10,7 @@ Adapted from the original test_ma by Pierre Gerard-Marchant
from __future__ import division, absolute_import, print_function
import warnings
+import itertools
import numpy as np
from numpy.testing import (
@@ -684,6 +685,37 @@ class TestMedian(TestCase):
assert_equal(ma_x.shape, (2,), "shape mismatch")
assert_(type(ma_x) is MaskedArray)
+ def test_axis_argument_errors(self):
+ msg = "mask = %s, ndim = %s, axis = %s, overwrite_input = %s"
+ for ndmin in range(5):
+ for mask in [False, True]:
+ x = array(1, ndmin=ndmin, mask=mask)
+
+ # Valid axis values should not raise exception
+ args = itertools.product(range(-ndmin, ndmin), [False, True])
+ for axis, over in args:
+ try:
+ np.ma.median(x, axis=axis, overwrite_input=over)
+ except:
+ raise AssertionError(msg % (mask, ndmin, axis, over))
+
+ # Invalid axis values should raise exception
+ args = itertools.product([-(ndmin + 1), ndmin], [False, True])
+ for axis, over in args:
+ try:
+ np.ma.median(x, axis=axis, overwrite_input=over)
+ except IndexError:
+ pass
+ else:
+ raise AssertionError(msg % (mask, ndmin, axis, over))
+
+ def test_masked_0d(self):
+ # Check values
+ x = array(1, mask=False)
+ assert_equal(np.ma.median(x), 1)
+ x = array(1, mask=True)
+ assert_equal(np.ma.median(x), np.ma.masked)
+
def test_masked_1d(self):
x = array(np.arange(5), mask=True)
assert_equal(np.ma.median(x), np.ma.masked)