summaryrefslogtreecommitdiff
path: root/numpy/core/_methods.py
diff options
context:
space:
mode:
authorFrederic Bastien <nouiz@nouiz.org>2016-10-28 20:09:45 -0400
committerFrederic Bastien <nouiz@nouiz.org>2016-10-28 21:33:04 -0400
commit530d67191287b8dc625bfc49664ad08b22e9e4d3 (patch)
tree66853b1486805c1f205f287fd207295d02697ed8 /numpy/core/_methods.py
parent6ae842001332f532e0c76815d49336ecc2b88dde (diff)
downloadnumpy-530d67191287b8dc625bfc49664ad08b22e9e4d3.tar.gz
[ENH]Make numpy.mean() do more precise computation without changing the output dtype that stay in float16.
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r--numpy/core/_methods.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 54e267541..b53c5ca00 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -54,20 +54,29 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False):
arr = asanyarray(a)
rcount = _count_reduce_items(arr, axis)
+ orig_dtype = dtype
# Make this warning show up first
if rcount == 0:
warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)
# Cast bool, unsigned int, and int to float64 by default
- if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
- dtype = mu.dtype('f8')
+ if dtype is None:
+ if issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
+ dtype = mu.dtype('f8')
+ elif issubclass(arr.dtype.type, nt.float16):
+ dtype = mu.dtype('f4')
ret = umr_sum(arr, axis, dtype, out, keepdims)
if isinstance(ret, mu.ndarray):
ret = um.true_divide(
ret, rcount, out=ret, casting='unsafe', subok=False)
+ if orig_dtype is None and issubclass(arr.dtype.type, nt.float16):
+ ret = a.dtype.type(ret)
elif hasattr(ret, 'dtype'):
- ret = ret.dtype.type(ret / rcount)
+ if orig_dtype is None and issubclass(arr.dtype.type, nt.float16):
+ ret = a.dtype.type(ret / rcount)
+ else:
+ ret = ret.dtype.type(ret / rcount)
else:
ret = ret / rcount