summaryrefslogtreecommitdiff
path: root/numpy/core/_methods.py
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-12-02 12:39:26 -0600
committerSebastian Berg <sebastian@sipsolutions.net>2020-12-02 14:41:59 -0600
commit755ea9a83e410c226c24b95cb892da3f64248d42 (patch)
treec69d51ad38cabef2c3bfe3ded013ce808d1536b6 /numpy/core/_methods.py
parent33dc7bea24f1ab6c47047b49521e732caeb485d5 (diff)
downloadnumpy-755ea9a83e410c226c24b95cb892da3f64248d42.tar.gz
ENH: Micro-optimize where=True path for mean, var, any, and all
This removes a 20%-30% overhead, and thus the largest chunk of slowdown incurred by adding the `where` argument. Most other places have fast-paths for `where=True`, this one also should have it. The additional argument does slow down the function versions a bit more than this, but that is to be expected probably (it has to build a new argument dict, at some point we might want to move this to C, but that seems worth much more with FASTCALL logic).
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r--numpy/core/_methods.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 75fd32ec8..c730e2035 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -51,9 +51,15 @@ def _prod(a, axis=None, dtype=None, out=None, keepdims=False,
return umr_prod(a, axis, dtype, out, keepdims, initial, where)
def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
+ # Parsing keyword arguments is currently fairly slow, so avoid it for now
+ if where is True:
+ return umr_any(a, axis, dtype, out, keepdims)
return umr_any(a, axis, dtype, out, keepdims, where=where)
def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
+ # Parsing keyword arguments is currently fairly slow, so avoid it for now
+ if where is True:
+ return umr_all(a, axis, dtype, out, keepdims)
return umr_all(a, axis, dtype, out, keepdims, where=where)
def _count_reduce_items(arr, axis, keepdims=False, where=True):
@@ -158,7 +164,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
is_float16_result = False
rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
- if umr_any(rcount == 0, axis=None):
+ if rcount == 0 if where is True else umr_any(rcount == 0):
warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)
# Cast bool, unsigned int, and int to float64 by default
@@ -191,7 +197,7 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
# Make this warning show up on top.
- if umr_any(ddof >= rcount, axis=None):
+ if ddof >= rcount if where is True else umr_any(ddof >= rcount):
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning,
stacklevel=2)