diff options
author | Frederic Bastien <nouiz@nouiz.org> | 2016-10-31 11:12:04 -0400 |
---|---|---|
committer | Frederic Bastien <nouiz@nouiz.org> | 2016-10-31 11:12:04 -0400 |
commit | e1e76fefbd5a41ae14308a43245b4ecdf3099252 (patch) | |
tree | a9794a2006e8702f81c2ee92961a8df3b7b195e9 /numpy/core/_methods.py | |
parent | 530d67191287b8dc625bfc49664ad08b22e9e4d3 (diff) | |
download | numpy-e1e76fefbd5a41ae14308a43245b4ecdf3099252.tar.gz |
Simplify and still reuse out with float16 inputs.
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r-- | numpy/core/_methods.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index b53c5ca00..4fdda242d 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -53,8 +53,8 @@ def _count_reduce_items(arr, axis): def _mean(a, axis=None, dtype=None, out=None, keepdims=False): arr = asanyarray(a) + is_float16_result = False rcount = _count_reduce_items(arr, axis) - orig_dtype = dtype # Make this warning show up first if rcount == 0: warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2) @@ -65,16 +65,17 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False): dtype = mu.dtype('f8') elif issubclass(arr.dtype.type, nt.float16): dtype = mu.dtype('f4') + is_float16_result = True ret = umr_sum(arr, axis, dtype, out, keepdims) if isinstance(ret, mu.ndarray): ret = um.true_divide( ret, rcount, out=ret, casting='unsafe', subok=False) - if orig_dtype is None and issubclass(arr.dtype.type, nt.float16): + if is_float16_result and out is None: ret = a.dtype.type(ret) elif hasattr(ret, 'dtype'): - if orig_dtype is None and issubclass(arr.dtype.type, nt.float16): - ret = a.dtype.type(ret / rcount) + if is_float16_result: + ret = nt.float16(ret / rcount) else: ret = ret.dtype.type(ret / rcount) else: |