diff options
author | Matti Picus <matti.picus@gmail.com> | 2019-12-03 12:11:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-03 12:11:37 +0200 |
commit | b03fab8dad4a165e25739f2081e3936b522554ac (patch) | |
tree | 201fe669886ee0733d2c4e6a4d5535499bbdfa3d | |
parent | fc860a2b279d5c370e2b332995b33528ebd97deb (diff) | |
parent | 2b51aa217bb7577f6c43c26f1156d9fc29536f96 (diff) | |
download | numpy-b03fab8dad4a165e25739f2081e3936b522554ac.tar.gz |
Merge pull request #15023 from qwhelan/nan_perf
MAINT: Only copy input array in _replace_nan() if there are nans to replace
-rw-r--r-- | numpy/lib/nanfunctions.py | 3 | ||||
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 29 |
2 files changed, 30 insertions, 2 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py index 457cca146..8e2a34e70 100644 --- a/numpy/lib/nanfunctions.py +++ b/numpy/lib/nanfunctions.py @@ -95,7 +95,7 @@ def _replace_nan(a, val): NaNs, otherwise return None. """ - a = np.array(a, subok=True, copy=True) + a = np.asanyarray(a) if a.dtype == np.object_: # object arrays do not support `isnan` (gh-9009), so make a guess @@ -106,6 +106,7 @@ def _replace_nan(a, val): mask = None if mask is not None: + a = np.array(a, subok=True, copy=True) np.copyto(a, val, where=mask) return a, mask diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index b7261c63f..da2d0cc52 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -4,7 +4,7 @@ import warnings import pytest import numpy as np -from numpy.lib.nanfunctions import _nan_mask +from numpy.lib.nanfunctions import _nan_mask, _replace_nan from numpy.testing import ( assert_, assert_equal, assert_almost_equal, assert_no_warnings, assert_raises, assert_array_equal, suppress_warnings @@ -953,3 +953,30 @@ def test__nan_mask(arr, expected): # for types that can't possibly contain NaN if type(expected) is not np.ndarray: assert actual is True + + +def test__replace_nan(): + """ Test that _replace_nan returns the original array if there are no + NaNs, not a copy. + """ + for dtype in [np.bool, np.int32, np.int64]: + arr = np.array([0, 1], dtype=dtype) + result, mask = _replace_nan(arr, 0) + assert mask is None + # do not make a copy if there are no nans + assert result is arr + + for dtype in [np.float32, np.float64]: + arr = np.array([0, 1], dtype=dtype) + result, mask = _replace_nan(arr, 2) + assert (mask == False).all() + # mask is not None, so we make a copy + assert result is not arr + assert_equal(result, arr) + + arr_nan = np.array([0, 1, np.nan], dtype=dtype) + result_nan, mask_nan = _replace_nan(arr_nan, 2) + assert_equal(mask_nan, np.array([False, False, True])) + assert result_nan is not arr_nan + assert_equal(result_nan, np.array([0, 1, 2])) + assert np.isnan(arr_nan[-1]) |