diff options
author | Simon Gasse <sgasse@users.noreply.github.com> | 2020-03-26 22:09:34 +0100 |
---|---|---|
committer | Simon Gasse <sgasse@users.noreply.github.com> | 2020-07-18 13:10:10 +0200 |
commit | 4ec1dbd8864363f77902d77e3d044a26eead31be (patch) | |
tree | 06a55d49ad04b2d6e3ded1f01b91578c3c3eeff7 | |
parent | 6ef5ec39cdfaf77aa4600ec2e3bf9f679a4fd527 (diff) | |
download | numpy-4ec1dbd8864363f77902d77e3d044a26eead31be.tar.gz |
ENH: Add where argument to several functions
Harmonize the signature of np.mean, np.var np.std, np.any, np.all,
and their respective nd.array methods with np.sum by adding a where
argument, see gh-15818.
-rw-r--r-- | doc/release/upcoming_changes/15852.new_feature.rst | 24 | ||||
-rw-r--r-- | numpy/core/_methods.py | 80 | ||||
-rw-r--r-- | numpy/core/fromnumeric.py | 111 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 132 |
4 files changed, 298 insertions, 49 deletions
diff --git a/doc/release/upcoming_changes/15852.new_feature.rst b/doc/release/upcoming_changes/15852.new_feature.rst new file mode 100644 index 000000000..12965e57b --- /dev/null +++ b/doc/release/upcoming_changes/15852.new_feature.rst @@ -0,0 +1,24 @@ +``where`` keyword argument for ``numpy.all`` and ``numpy.any`` functions +------------------------------------------------------------------------ +The keyword argument ``where`` is added and allows to only consider specified +elements or subaxes from an array in the Boolean evaluation of ``all`` and +``any``. This new keyword is available to the functions ``all`` and ``any`` +both via ``numpy`` directly or in the methods of ``numpy.ndarray``. + +Any broadcastable Boolean array or a scalar can be set as ``where``. It +defaults to ``True`` to evaluate the functions for all elements in an array if +``where`` is not set by the user. Examples are given in the documentation of +the functions. + + +``where`` keyword argument for ``numpy`` functions ``mean``, ``std``, ``var`` +----------------------------------------------------------------------------- +The keyword argument ``where`` is added and allows to limit the scope in the +caluclation of ``mean``, ``std`` and ``var`` to only a subset of elements. It +is available both via ``numpy`` directly or in the methods of +``numpy.ndarray``. + +Any broadcastable Boolean array or a scalar can be set as ``where``. It +defaults to ``True`` to evaluate the functions for all elements in an array if +``where`` is not set by the user. Examples are given in the documentation of +the functions. diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 86ddf4d17..75fd32ec8 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -50,20 +50,32 @@ def _prod(a, axis=None, dtype=None, out=None, keepdims=False, initial=_NoValue, where=True): return umr_prod(a, axis, dtype, out, keepdims, initial, where) -def _any(a, axis=None, dtype=None, out=None, keepdims=False): - return umr_any(a, axis, dtype, out, keepdims) - -def _all(a, axis=None, dtype=None, out=None, keepdims=False): - return umr_all(a, axis, dtype, out, keepdims) - -def _count_reduce_items(arr, axis): - if axis is None: - axis = tuple(range(arr.ndim)) - if not isinstance(axis, tuple): - axis = (axis,) - items = 1 - for ax in axis: - items *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)] +def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + return umr_any(a, axis, dtype, out, keepdims, where=where) + +def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + return umr_all(a, axis, dtype, out, keepdims, where=where) + +def _count_reduce_items(arr, axis, keepdims=False, where=True): + # fast-path for the default case + if where is True: + # no boolean mask given, calculate items according to axis + if axis is None: + axis = tuple(range(arr.ndim)) + elif not isinstance(axis, tuple): + axis = (axis,) + items = nt.intp(1) + for ax in axis: + items *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)] + else: + # TODO: Optimize case when `where` is broadcast along a non-reduction + # axis and full sum is more excessive than needed. + + # guarded to protect circular imports + from numpy.lib.stride_tricks import broadcast_to + # count True values in (potentially broadcasted) boolean mask + items = umr_sum(broadcast_to(where, arr.shape), axis, nt.intp, None, + keepdims) return items # Numpy 1.17.0, 2019-02-24 @@ -140,13 +152,13 @@ def _clip(a, min=None, max=None, out=None, *, casting=None, **kwargs): return _clip_dep_invoke_with_casting( um.clip, a, min, max, out=out, casting=casting, **kwargs) -def _mean(a, axis=None, dtype=None, out=None, keepdims=False): +def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): arr = asanyarray(a) is_float16_result = False - rcount = _count_reduce_items(arr, axis) - # Make this warning show up first - if rcount == 0: + + rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) + if umr_any(rcount == 0, axis=None): warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2) # Cast bool, unsigned int, and int to float64 by default @@ -157,7 +169,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False): dtype = mu.dtype('f4') is_float16_result = True - ret = umr_sum(arr, axis, dtype, out, keepdims) + ret = umr_sum(arr, axis, dtype, out, keepdims, where=where) if isinstance(ret, mu.ndarray): ret = um.true_divide( ret, rcount, out=ret, casting='unsafe', subok=False) @@ -173,12 +185,13 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False): return ret -def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): +def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, + where=True): arr = asanyarray(a) - rcount = _count_reduce_items(arr, axis) + rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) # Make this warning show up on top. - if ddof >= rcount: + if umr_any(ddof >= rcount, axis=None): warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) @@ -189,10 +202,18 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # Compute the mean. # Note that if dtype is not of inexact type then arraymean will # not be either. - arrmean = umr_sum(arr, axis, dtype, keepdims=True) + arrmean = umr_sum(arr, axis, dtype, keepdims=True, where=where) + # The shape of rcount has to match arrmean to not change the shape of out + # in broadcasting. Otherwise, it cannot be stored back to arrmean. + if rcount.ndim == 0: + # fast-path for default case when where is True + div = rcount + else: + # matching rcount to arrmean when where is specified as array + div = rcount.reshape(arrmean.shape) if isinstance(arrmean, mu.ndarray): - arrmean = um.true_divide( - arrmean, rcount, out=arrmean, casting='unsafe', subok=False) + arrmean = um.true_divide(arrmean, div, out=arrmean, casting='unsafe', + subok=False) else: arrmean = arrmean.dtype.type(arrmean / rcount) @@ -213,10 +234,10 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): else: x = um.multiply(x, um.conjugate(x), out=x).real - ret = umr_sum(x, axis, dtype, out, keepdims) + ret = umr_sum(x, axis, dtype, out, keepdims=keepdims, where=where) # Compute degrees of freedom and make sure it is not negative. - rcount = max([rcount - ddof, 0]) + rcount = um.maximum(rcount - ddof, 0) # divide by degrees of freedom if isinstance(ret, mu.ndarray): @@ -229,9 +250,10 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): return ret -def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): +def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, + where=True): ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof, - keepdims=keepdims) + keepdims=keepdims, where=where) if isinstance(ret, mu.ndarray): ret = um.sqrt(ret, out=ret) diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index f8c11c015..dcbb84f33 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -2253,12 +2253,13 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=np._NoValue, initial=initial, where=where) -def _any_dispatcher(a, axis=None, out=None, keepdims=None): - return (a, out) +def _any_dispatcher(a, axis=None, out=None, keepdims=None, *, + where=np._NoValue): + return (a, where, out) @array_function_dispatch(_any_dispatcher) -def any(a, axis=None, out=None, keepdims=np._NoValue): +def any(a, axis=None, out=None, keepdims=np._NoValue, *, where=np._NoValue): """ Test whether any array element along a given axis evaluates to True. @@ -2296,6 +2297,12 @@ def any(a, axis=None, out=None, keepdims=np._NoValue): sub-class' method does not implement `keepdims` any exceptions will be raised. + where : array_like of bool, optional + Elements to include in checking for any `True` values. + See `~numpy.ufunc.reduce` for details. + + .. versionadded:: 1.20.0 + Returns ------- any : bool or ndarray @@ -2327,6 +2334,9 @@ def any(a, axis=None, out=None, keepdims=np._NoValue): >>> np.any(np.nan) True + >>> np.any([[True, False], [False, False]], where=[[False], [True]]) + False + >>> o=np.array(False) >>> z=np.any([-1, 4, 5], out=o) >>> z, o @@ -2338,15 +2348,17 @@ def any(a, axis=None, out=None, keepdims=np._NoValue): (191614240, 191614240) """ - return _wrapreduction(a, np.logical_or, 'any', axis, None, out, keepdims=keepdims) + return _wrapreduction(a, np.logical_or, 'any', axis, None, out, + keepdims=keepdims, where=where) -def _all_dispatcher(a, axis=None, out=None, keepdims=None): - return (a, out) +def _all_dispatcher(a, axis=None, out=None, keepdims=None, *, + where=None): + return (a, where, out) @array_function_dispatch(_all_dispatcher) -def all(a, axis=None, out=None, keepdims=np._NoValue): +def all(a, axis=None, out=None, keepdims=np._NoValue, *, where=np._NoValue): """ Test whether all array elements along a given axis evaluate to True. @@ -2382,6 +2394,12 @@ def all(a, axis=None, out=None, keepdims=np._NoValue): sub-class' method does not implement `keepdims` any exceptions will be raised. + where : array_like of bool, optional + Elements to include in checking for all `True` values. + See `~numpy.ufunc.reduce` for details. + + .. versionadded:: 1.20.0 + Returns ------- all : ndarray, bool @@ -2413,13 +2431,17 @@ def all(a, axis=None, out=None, keepdims=np._NoValue): >>> np.all([1.0, np.nan]) True + >>> np.all([[True, True], [False, True]], where=[[True], [False]]) + True + >>> o=np.array(False) >>> z=np.all([-1, 4, 5], out=o) >>> id(z), id(o), z (28293632, 28293632, array(True)) # may vary """ - return _wrapreduction(a, np.logical_and, 'all', axis, None, out, keepdims=keepdims) + return _wrapreduction(a, np.logical_and, 'all', axis, None, out, + keepdims=keepdims, where=where) def _cumsum_dispatcher(a, axis=None, dtype=None, out=None): @@ -3276,12 +3298,14 @@ def around(a, decimals=0, out=None): return _wrapfunc(a, 'round', decimals=decimals, out=out) -def _mean_dispatcher(a, axis=None, dtype=None, out=None, keepdims=None): - return (a, out) +def _mean_dispatcher(a, axis=None, dtype=None, out=None, keepdims=None, *, + where=None): + return (a, where, out) @array_function_dispatch(_mean_dispatcher) -def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): +def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue, *, + where=np._NoValue): """ Compute the arithmetic mean along the specified axis. @@ -3323,6 +3347,11 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): sub-class' method does not implement `keepdims` any exceptions will be raised. + where : array_like of bool, optional + Elements to include in the mean. See `~numpy.ufunc.reduce` for details. + + .. versionadded:: 1.20.0 + Returns ------- m : ndarray, see dtype parameter above @@ -3371,10 +3400,19 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): >>> np.mean(a, dtype=np.float64) 0.55000000074505806 # may vary + Specifying a where argument: + >>> a = np.array([[5, 9, 13], [14, 10, 12], [11, 15, 19]]) + >>> np.mean(a) + 12.0 + >>> np.mean(a, where=[[True], [False], [False]]) + 9.0 + """ kwargs = {} if keepdims is not np._NoValue: kwargs['keepdims'] = keepdims + if where is not np._NoValue: + kwargs['where'] = where if type(a) is not mu.ndarray: try: mean = a.mean @@ -3387,13 +3425,14 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): out=out, **kwargs) -def _std_dispatcher( - a, axis=None, dtype=None, out=None, ddof=None, keepdims=None): - return (a, out) +def _std_dispatcher(a, axis=None, dtype=None, out=None, ddof=None, + keepdims=None, *, where=None): + return (a, where, out) @array_function_dispatch(_std_dispatcher) -def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): +def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue, *, + where=np._NoValue): """ Compute the standard deviation along the specified axis. @@ -3436,6 +3475,12 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): sub-class' method does not implement `keepdims` any exceptions will be raised. + where : array_like of bool, optional + Elements to include in the standard deviation. + See `~numpy.ufunc.reduce` for details. + + .. versionadded:: 1.20.0 + Returns ------- standard_deviation : ndarray, see dtype parameter above. @@ -3495,11 +3540,20 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): >>> np.std(a, dtype=np.float64) 0.44999999925494177 # may vary + Specifying a where argument: + + >>> a = np.array([[14, 8, 11, 10], [7, 9, 10, 11], [10, 15, 5, 10]]) + >>> np.std(a) + 2.614064523559687 # may vary + >>> np.std(a, where=[[True], [True], [False]]) + 2.0 + """ kwargs = {} if keepdims is not np._NoValue: kwargs['keepdims'] = keepdims - + if where is not np._NoValue: + kwargs['where'] = where if type(a) is not mu.ndarray: try: std = a.std @@ -3512,13 +3566,14 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): **kwargs) -def _var_dispatcher( - a, axis=None, dtype=None, out=None, ddof=None, keepdims=None): - return (a, out) +def _var_dispatcher(a, axis=None, dtype=None, out=None, ddof=None, + keepdims=None, *, where=None): + return (a, where, out) @array_function_dispatch(_var_dispatcher) -def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): +def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue, *, + where=np._NoValue): """ Compute the variance along the specified axis. @@ -3562,6 +3617,12 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): sub-class' method does not implement `keepdims` any exceptions will be raised. + where : array_like of bool, optional + Elements to include in the variance. See `~numpy.ufunc.reduce` for + details. + + .. versionadded:: 1.20.0 + Returns ------- variance : ndarray, see dtype parameter above @@ -3619,10 +3680,20 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): >>> ((1-0.55)**2 + (0.1-0.55)**2)/2 0.2025 + Specifying a where argument: + + >>> a = np.array([[14, 8, 11, 10], [7, 9, 10, 11], [10, 15, 5, 10]]) + >>> np.var(a) + 6.833333333333333 # may vary + >>> np.var(a, where=[[True], [True], [False]]) + 4.0 + """ kwargs = {} if keepdims is not np._NoValue: kwargs['keepdims'] = keepdims + if where is not np._NoValue: + kwargs['where'] = where if type(a) is not mu.ndarray: try: diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index b7d4a6a92..fca646acb 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1587,6 +1587,47 @@ class TestMethods: sort_kinds = ['quicksort', 'heapsort', 'stable'] + def test_all_where(self): + a = np.array([[True, False, True], + [False, False, False], + [True, True, True]]) + wh_full = np.array([[True, False, True], + [False, False, False], + [True, False, True]]) + wh_lower = np.array([[False], + [False], + [True]]) + for _ax in [0, None]: + assert_equal(a.all(axis=_ax, where=wh_lower), + np.all(a[wh_lower[:,0],:], axis=_ax)) + assert_equal(np.all(a, axis=_ax, where=wh_lower), + a[wh_lower[:,0],:].all(axis=_ax)) + + assert_equal(a.all(where=wh_full), True) + assert_equal(np.all(a, where=wh_full), True) + assert_equal(a.all(where=False), True) + assert_equal(np.all(a, where=False), True) + + def test_any_where(self): + a = np.array([[True, False, True], + [False, False, False], + [True, True, True]]) + wh_full = np.array([[False, True, False], + [True, True, True], + [False, False, False]]) + wh_middle = np.array([[False], + [True], + [False]]) + for _ax in [0, None]: + assert_equal(a.any(axis=_ax, where=wh_middle), + np.any(a[wh_middle[:,0],:], axis=_ax)) + assert_equal(np.any(a, axis=_ax, where=wh_middle), + a[wh_middle[:,0],:].any(axis=_ax)) + assert_equal(a.any(where=wh_full), False) + assert_equal(np.any(a, where=wh_full), False) + assert_equal(a.any(where=False), False) + assert_equal(np.any(a, where=False), False) + def test_compress(self): tgt = [[5, 6, 7, 8, 9]] arr = np.arange(10).reshape(2, 5) @@ -5575,6 +5616,33 @@ class TestStats: with assert_raises(np.core._exceptions.AxisError): np.arange(10).mean(axis=2) + def test_mean_where(self): + a = np.arange(16).reshape((4, 4)) + wh_full = np.array([[False, True, False, True], + [True, False, True, False], + [True, True, False, False], + [False, False, True, True]]) + wh_partial = np.array([[False], + [True], + [True], + [False]]) + _cases = [(1, True, [1.5, 5.5, 9.5, 13.5]), + (0, wh_full, [6., 5., 10., 9.]), + (1, wh_full, [2., 5., 8.5, 14.5]), + (0, wh_partial, [6., 7., 8., 9.])] + for _ax, _wh, _res in _cases: + assert_allclose(a.mean(axis=_ax, where=_wh), + np.array(_res)) + assert_allclose(np.mean(a, axis=_ax, where=_wh), + np.array(_res)) + with pytest.warns(RuntimeWarning) as w: + assert_allclose(a.mean(axis=1, where=wh_partial), + np.array([np.nan, 5.5, 9.5, np.nan])) + with pytest.warns(RuntimeWarning) as w: + assert_equal(a.mean(where=False), np.nan) + with pytest.warns(RuntimeWarning) as w: + assert_equal(np.mean(a, where=False), np.nan) + def test_var_values(self): for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1, None]: @@ -5623,6 +5691,34 @@ class TestStats: with assert_raises(np.core._exceptions.AxisError): np.arange(10).var(axis=2) + def test_var_where(self): + a = np.arange(25).reshape((5, 5)) + wh_full = np.array([[False, True, False, True, True], + [True, False, True, True, False], + [True, True, False, False, True], + [False, True, True, False, True], + [True, False, True, True, False]]) + wh_partial = np.array([[False], + [True], + [True], + [False], + [True]]) + _cases = [(0, True, [50., 50., 50., 50., 50.]), + (1, True, [2., 2., 2., 2., 2.])] + for _ax, _wh, _res in _cases: + assert_allclose(a.var(axis=_ax, where=_wh), + np.array(_res)) + assert_allclose(np.var(a, axis=_ax, where=_wh), + np.array(_res)) + assert_allclose(np.var(a, axis=1, where=wh_full), + np.var(a[wh_full].reshape((5, 3)), axis=1)) + assert_allclose(np.var(a, axis=0, where=wh_partial), + np.var(a[wh_partial[:,0]], axis=0)) + with pytest.warns(RuntimeWarning) as w: + assert_equal(a.var(where=False), np.nan) + with pytest.warns(RuntimeWarning) as w: + assert_equal(np.var(a, where=False), np.nan) + def test_std_values(self): for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1, None]: @@ -5630,6 +5726,42 @@ class TestStats: res = _std(mat, axis=axis) assert_almost_equal(res, tgt) + def test_std_where(self): + a = np.arange(25).reshape((5,5))[::-1] + whf = np.array([[False, True, False, True, True], + [True, False, True, False, True], + [True, True, False, True, False], + [True, False, True, True, False], + [False, True, False, True, True]]) + whp = np.array([[False], + [False], + [True], + [True], + [False]]) + _cases = [ + (0, True, 7.07106781*np.ones((5))), + (1, True, 1.41421356*np.ones((5))), + (0, whf, + np.array([4.0824829 , 8.16496581, 5., 7.39509973, 8.49836586])), + (0, whp, 2.5*np.ones((5))) + ] + for _ax, _wh, _res in _cases: + assert_allclose(a.std(axis=_ax, where=_wh), _res) + assert_allclose(np.std(a, axis=_ax, where=_wh), _res) + + assert_allclose(a.std(axis=1, where=whf), + np.std(a[whf].reshape((5,3)), axis=1)) + assert_allclose(np.std(a, axis=1, where=whf), + (a[whf].reshape((5,3))).std(axis=1)) + assert_allclose(a.std(axis=0, where=whp), + np.std(a[whp[:,0]], axis=0)) + assert_allclose(np.std(a, axis=0, where=whp), + (a[whp[:,0]]).std(axis=0)) + with pytest.warns(RuntimeWarning) as w: + assert_equal(a.std(where=False), np.nan) + with pytest.warns(RuntimeWarning) as w: + assert_equal(np.std(a, where=False), np.nan) + def test_subclass(self): class TestArray(np.ndarray): def __new__(cls, data, info): |