diff options
author | Ross Barnowski <rossbar@berkeley.edu> | 2020-03-03 15:02:49 -0800 |
---|---|---|
committer | Ross Barnowski <rossbar@berkeley.edu> | 2020-03-09 10:39:08 -0700 |
commit | c898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9 (patch) | |
tree | 6dc8b71d58e0f64a85134ff8c93b7988b2b942c2 /numpy/core/_methods.py | |
parent | 6894bbc6d396b87464cbc21516d239d5f94f13b7 (diff) | |
download | numpy-c898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9.tar.gz |
ENH: Adds a fast path to var for complex input
var currently has a conditional that results in conjugate being
called for the variance calculation of complex inputs.
This leg of the computation is slow. This PR avoids this
computational leg for complex inputs via a type check.
Closes #15684
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r-- | numpy/core/_methods.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 694523b20..8a90731e9 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -21,6 +21,21 @@ umr_prod = um.multiply.reduce umr_any = um.logical_or.reduce umr_all = um.logical_and.reduce +# Complex types to -> (2,)float view for fast-path computation in _var() +_complex_to_float = { + nt.dtype(nt.csingle) : nt.dtype(nt.single), + nt.dtype(nt.cdouble) : nt.dtype(nt.double), +} +# Special case for windows: ensure double takes precedence +if nt.dtype(nt.longdouble) != nt.dtype(nt.double): + _complex_to_float.update({ + nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble), + }) +# Add reverse-endian types +_complex_to_float.update({ + k.newbyteorder() : v.newbyteorder() for k, v in _complex_to_float.items() +}) + # avoid keyword arguments to speed up parsing, saves about 15%-20% for very # small reductions def _amax(a, axis=None, out=None, keepdims=False, @@ -189,8 +204,16 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # Note that x may not be inexact and that we need it to be an array, # not a scalar. x = asanyarray(arr - arrmean) + if issubclass(arr.dtype.type, (nt.floating, nt.integer)): x = um.multiply(x, x, out=x) + # Fast-paths for built-in complex types + elif x.dtype in _complex_to_float: + xv = x.view(dtype=(_complex_to_float[x.dtype], (2,))) + um.multiply(xv, xv, out=xv) + x = um.add(xv[..., 0], xv[..., 1], out=x.real).real + # Most general case; includes handling object arrays containing imaginary + # numbers and complex types with non-native byteorder else: x = um.multiply(x, um.conjugate(x), out=x).real |