From 755ea9a83e410c226c24b95cb892da3f64248d42 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Wed, 2 Dec 2020 12:39:26 -0600 Subject: ENH: Micro-optimize where=True path for mean, var, any, and all This removes a 20%-30% overhead, and thus the largest chunk of slowdown incurred by adding the `where` argument. Most other places have fast-paths for `where=True`, this one also should have it. The additional argument does slow down the function versions a bit more than this, but that is to be expected probably (it has to build a new argument dict, at some point we might want to move this to C, but that seems worth much more with FASTCALL logic). --- numpy/core/_methods.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'numpy/core/_methods.py') diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 75fd32ec8..c730e2035 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -51,9 +51,15 @@ def _prod(a, axis=None, dtype=None, out=None, keepdims=False, return umr_prod(a, axis, dtype, out, keepdims, initial, where) def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + # Parsing keyword arguments is currently fairly slow, so avoid it for now + if where is True: + return umr_any(a, axis, dtype, out, keepdims) return umr_any(a, axis, dtype, out, keepdims, where=where) def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + # Parsing keyword arguments is currently fairly slow, so avoid it for now + if where is True: + return umr_all(a, axis, dtype, out, keepdims) return umr_all(a, axis, dtype, out, keepdims, where=where) def _count_reduce_items(arr, axis, keepdims=False, where=True): @@ -158,7 +164,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): is_float16_result = False rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) - if umr_any(rcount == 0, axis=None): + if rcount == 0 if where is True else umr_any(rcount == 0): warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2) # Cast bool, unsigned int, and int to float64 by default @@ -191,7 +197,7 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) # Make this warning show up on top. - if umr_any(ddof >= rcount, axis=None): + if ddof >= rcount if where is True else umr_any(ddof >= rcount): warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) -- cgit v1.2.1