summaryrefslogtreecommitdiff
path: root/numpy/core/_methods.py
diff options
context:
space:
mode:
authorFrederic Bastien <nouiz@nouiz.org>2016-10-31 11:12:04 -0400
committerFrederic Bastien <nouiz@nouiz.org>2016-10-31 11:12:04 -0400
commite1e76fefbd5a41ae14308a43245b4ecdf3099252 (patch)
treea9794a2006e8702f81c2ee92961a8df3b7b195e9 /numpy/core/_methods.py
parent530d67191287b8dc625bfc49664ad08b22e9e4d3 (diff)
downloadnumpy-e1e76fefbd5a41ae14308a43245b4ecdf3099252.tar.gz
Simplify and still reuse out with float16 inputs.
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r--numpy/core/_methods.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index b53c5ca00..4fdda242d 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -53,8 +53,8 @@ def _count_reduce_items(arr, axis):
def _mean(a, axis=None, dtype=None, out=None, keepdims=False):
arr = asanyarray(a)
+ is_float16_result = False
rcount = _count_reduce_items(arr, axis)
- orig_dtype = dtype
# Make this warning show up first
if rcount == 0:
warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)
@@ -65,16 +65,17 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False):
dtype = mu.dtype('f8')
elif issubclass(arr.dtype.type, nt.float16):
dtype = mu.dtype('f4')
+ is_float16_result = True
ret = umr_sum(arr, axis, dtype, out, keepdims)
if isinstance(ret, mu.ndarray):
ret = um.true_divide(
ret, rcount, out=ret, casting='unsafe', subok=False)
- if orig_dtype is None and issubclass(arr.dtype.type, nt.float16):
+ if is_float16_result and out is None:
ret = a.dtype.type(ret)
elif hasattr(ret, 'dtype'):
- if orig_dtype is None and issubclass(arr.dtype.type, nt.float16):
- ret = a.dtype.type(ret / rcount)
+ if is_float16_result:
+ ret = nt.float16(ret / rcount)
else:
ret = ret.dtype.type(ret / rcount)
else: