From c898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9 Mon Sep 17 00:00:00 2001 From: Ross Barnowski Date: Tue, 3 Mar 2020 15:02:49 -0800 Subject: ENH: Adds a fast path to var for complex input var currently has a conditional that results in conjugate being called for the variance calculation of complex inputs. This leg of the computation is slow. This PR avoids this computational leg for complex inputs via a type check. Closes #15684 --- numpy/core/_methods.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'numpy/core/_methods.py') diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 694523b20..8a90731e9 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -21,6 +21,21 @@ umr_prod = um.multiply.reduce umr_any = um.logical_or.reduce umr_all = um.logical_and.reduce +# Complex types to -> (2,)float view for fast-path computation in _var() +_complex_to_float = { + nt.dtype(nt.csingle) : nt.dtype(nt.single), + nt.dtype(nt.cdouble) : nt.dtype(nt.double), +} +# Special case for windows: ensure double takes precedence +if nt.dtype(nt.longdouble) != nt.dtype(nt.double): + _complex_to_float.update({ + nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble), + }) +# Add reverse-endian types +_complex_to_float.update({ + k.newbyteorder() : v.newbyteorder() for k, v in _complex_to_float.items() +}) + # avoid keyword arguments to speed up parsing, saves about 15%-20% for very # small reductions def _amax(a, axis=None, out=None, keepdims=False, @@ -189,8 +204,16 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # Note that x may not be inexact and that we need it to be an array, # not a scalar. x = asanyarray(arr - arrmean) + if issubclass(arr.dtype.type, (nt.floating, nt.integer)): x = um.multiply(x, x, out=x) + # Fast-paths for built-in complex types + elif x.dtype in _complex_to_float: + xv = x.view(dtype=(_complex_to_float[x.dtype], (2,))) + um.multiply(xv, xv, out=xv) + x = um.add(xv[..., 0], xv[..., 1], out=x.real).real + # Most general case; includes handling object arrays containing imaginary + # numbers and complex types with non-native byteorder else: x = um.multiply(x, um.conjugate(x), out=x).real -- cgit v1.2.1