diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-12-02 12:39:26 -0600 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2020-12-02 14:41:59 -0600 |
commit | 755ea9a83e410c226c24b95cb892da3f64248d42 (patch) | |
tree | c69d51ad38cabef2c3bfe3ded013ce808d1536b6 /numpy/core/_methods.py | |
parent | 33dc7bea24f1ab6c47047b49521e732caeb485d5 (diff) | |
download | numpy-755ea9a83e410c226c24b95cb892da3f64248d42.tar.gz |
ENH: Micro-optimize where=True path for mean, var, any, and all
This removes a 20%-30% overhead, and thus the largest chunk of
slowdown incurred by adding the `where` argument. Most other places
have fast-paths for `where=True`, this one also should have it.
The additional argument does slow down the function versions a bit
more than this, but that is to be expected probably (it has to
build a new argument dict, at some point we might want to move this
to C, but that seems worth much more with FASTCALL logic).
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r-- | numpy/core/_methods.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 75fd32ec8..c730e2035 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -51,9 +51,15 @@ def _prod(a, axis=None, dtype=None, out=None, keepdims=False, return umr_prod(a, axis, dtype, out, keepdims, initial, where) def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + # Parsing keyword arguments is currently fairly slow, so avoid it for now + if where is True: + return umr_any(a, axis, dtype, out, keepdims) return umr_any(a, axis, dtype, out, keepdims, where=where) def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + # Parsing keyword arguments is currently fairly slow, so avoid it for now + if where is True: + return umr_all(a, axis, dtype, out, keepdims) return umr_all(a, axis, dtype, out, keepdims, where=where) def _count_reduce_items(arr, axis, keepdims=False, where=True): @@ -158,7 +164,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): is_float16_result = False rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) - if umr_any(rcount == 0, axis=None): + if rcount == 0 if where is True else umr_any(rcount == 0): warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2) # Cast bool, unsigned int, and int to float64 by default @@ -191,7 +197,7 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) # Make this warning show up on top. - if umr_any(ddof >= rcount, axis=None): + if ddof >= rcount if where is True else umr_any(ddof >= rcount): warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) |