summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMateusz Sokół <8431159+mtsokol@users.noreply.github.com>2021-03-11 23:59:16 +0100
committerGitHub <noreply@github.com>2021-03-11 16:59:16 -0600
commitc5de5b5c2cf048e1556f31dfcfa031c8f624b98e (patch)
tree985787b74ce260686a25d80f16e971ddbd3479e7 /numpy
parente8d20b5731e965127c6157a1f34f6970a8ae550c (diff)
downloadnumpy-c5de5b5c2cf048e1556f31dfcfa031c8f624b98e.tar.gz
BUG: Fixed ``where`` keyword for ``np.mean`` & ``np.var`` methods (gh-18560)
* Fixed keyword bug * Added test case * Reverted to original notation * Added tests for var and std Closes gh-18552
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_methods.py4
-rw-r--r--numpy/core/tests/test_multiarray.py26
2 files changed, 28 insertions, 2 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 1867ba68c..09147fe5b 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -165,7 +165,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
is_float16_result = False
rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
- if rcount == 0 if where is True else umr_any(rcount == 0):
+ if rcount == 0 if where is True else umr_any(rcount == 0, axis=None):
warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)
# Cast bool, unsigned int, and int to float64 by default
@@ -198,7 +198,7 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
# Make this warning show up on top.
- if ddof >= rcount if where is True else umr_any(ddof >= rcount):
+ if ddof >= rcount if where is True else umr_any(ddof >= rcount, axis=None):
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning,
stacklevel=2)
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index b30fcb812..cffb1af99 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -5720,6 +5720,15 @@ class TestStats:
np.array(_res))
assert_allclose(np.mean(a, axis=_ax, where=_wh),
np.array(_res))
+
+ a3d = np.arange(16).reshape((2, 2, 4))
+ _wh_partial = np.array([False, True, True, False])
+ _res = [[1.5, 5.5], [9.5, 13.5]]
+ assert_allclose(a3d.mean(axis=2, where=_wh_partial),
+ np.array(_res))
+ assert_allclose(np.mean(a3d, axis=2, where=_wh_partial),
+ np.array(_res))
+
with pytest.warns(RuntimeWarning) as w:
assert_allclose(a.mean(axis=1, where=wh_partial),
np.array([np.nan, 5.5, 9.5, np.nan]))
@@ -5795,6 +5804,15 @@ class TestStats:
np.array(_res))
assert_allclose(np.var(a, axis=_ax, where=_wh),
np.array(_res))
+
+ a3d = np.arange(16).reshape((2, 2, 4))
+ _wh_partial = np.array([False, True, True, False])
+ _res = [[0.25, 0.25], [0.25, 0.25]]
+ assert_allclose(a3d.var(axis=2, where=_wh_partial),
+ np.array(_res))
+ assert_allclose(np.var(a3d, axis=2, where=_wh_partial),
+ np.array(_res))
+
assert_allclose(np.var(a, axis=1, where=wh_full),
np.var(a[wh_full].reshape((5, 3)), axis=1))
assert_allclose(np.var(a, axis=0, where=wh_partial),
@@ -5834,6 +5852,14 @@ class TestStats:
assert_allclose(a.std(axis=_ax, where=_wh), _res)
assert_allclose(np.std(a, axis=_ax, where=_wh), _res)
+ a3d = np.arange(16).reshape((2, 2, 4))
+ _wh_partial = np.array([False, True, True, False])
+ _res = [[0.5, 0.5], [0.5, 0.5]]
+ assert_allclose(a3d.std(axis=2, where=_wh_partial),
+ np.array(_res))
+ assert_allclose(np.std(a3d, axis=2, where=_wh_partial),
+ np.array(_res))
+
assert_allclose(a.std(axis=1, where=whf),
np.std(a[whf].reshape((5,3)), axis=1))
assert_allclose(np.std(a, axis=1, where=whf),