summaryrefslogtreecommitdiff
path: root/numpy/core/_methods.py
diff options
context:
space:
mode:
authorRoss Barnowski <rossbar@berkeley.edu>2020-03-03 15:02:49 -0800
committerRoss Barnowski <rossbar@berkeley.edu>2020-03-09 10:39:08 -0700
commitc898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9 (patch)
tree6dc8b71d58e0f64a85134ff8c93b7988b2b942c2 /numpy/core/_methods.py
parent6894bbc6d396b87464cbc21516d239d5f94f13b7 (diff)
downloadnumpy-c898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9.tar.gz
ENH: Adds a fast path to var for complex input
`var` currently has a conditional that causes `conjugate` to be called when computing the variance of complex inputs, and this branch of the computation is slow. This PR adds a type check so that built-in complex inputs take a faster path and bypass that branch. Closes #15684
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r--numpy/core/_methods.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 694523b20..8a90731e9 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -21,6 +21,21 @@ umr_prod = um.multiply.reduce
umr_any = um.logical_or.reduce
umr_all = um.logical_and.reduce
+# Map each complex dtype to the float dtype used for its (2,)-float view in the _var() fast path
+_complex_to_float = {
+ nt.dtype(nt.csingle) : nt.dtype(nt.single),
+ nt.dtype(nt.cdouble) : nt.dtype(nt.double),
+}
+# Special case for windows: ensure double takes precedence
+if nt.dtype(nt.longdouble) != nt.dtype(nt.double):
+ _complex_to_float.update({
+ nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble),
+ })
+# Add reverse-endian types
+_complex_to_float.update({
+ k.newbyteorder() : v.newbyteorder() for k, v in _complex_to_float.items()
+})
+
# avoid keyword arguments to speed up parsing, saves about 15%-20% for very
# small reductions
def _amax(a, axis=None, out=None, keepdims=False,
@@ -189,8 +204,16 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
# Note that x may not be inexact and that we need it to be an array,
# not a scalar.
x = asanyarray(arr - arrmean)
+
if issubclass(arr.dtype.type, (nt.floating, nt.integer)):
x = um.multiply(x, x, out=x)
+ # Fast-paths for built-in complex types
+ elif x.dtype in _complex_to_float:
+ xv = x.view(dtype=(_complex_to_float[x.dtype], (2,)))
+ um.multiply(xv, xv, out=xv)
+ x = um.add(xv[..., 0], xv[..., 1], out=x.real).real
+ # Most general case; includes handling object arrays containing imaginary
+ # numbers and complex types with non-native byteorder
else:
x = um.multiply(x, um.conjugate(x), out=x).real