diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-10-04 09:30:36 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-04 09:30:36 -0600 |
commit | 7dabf22f1ea47633f4c7fb2848ae0612742687c8 (patch) | |
tree | 5385fcba422feaff4c97009269faa8206c246b02 | |
parent | e7842a29b0eed67f131d49143525145a9639c935 (diff) | |
parent | 476903f189c01cc9b0480b10d80bfca1070ee442 (diff) | |
download | numpy-7dabf22f1ea47633f4c7fb2848ae0612742687c8.tar.gz |
Merge pull request #20008 from BvB93/window-func
BUG: Fix the `lib.function_base` window functions ignoring extended precision float dtypes
-rw-r--r-- | numpy/lib/function_base.py | 18 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 107 |
2 files changed, 103 insertions, 22 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 80eaf8acf..3ca566f73 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -2804,9 +2804,9 @@ def blackman(M): """ if M < 1: - return array([]) + return array([], dtype=np.result_type(M, 0.0)) if M == 1: - return ones(1, float) + return ones(1, dtype=np.result_type(M, 0.0)) n = arange(1-M, M, 2) return 0.42 + 0.5*cos(pi*n/(M-1)) + 0.08*cos(2.0*pi*n/(M-1)) @@ -2913,9 +2913,9 @@ def bartlett(M): """ if M < 1: - return array([]) + return array([], dtype=np.result_type(M, 0.0)) if M == 1: - return ones(1, float) + return ones(1, dtype=np.result_type(M, 0.0)) n = arange(1-M, M, 2) return where(less_equal(n, 0), 1 + n/(M-1), 1 - n/(M-1)) @@ -3017,9 +3017,9 @@ def hanning(M): """ if M < 1: - return array([]) + return array([], dtype=np.result_type(M, 0.0)) if M == 1: - return ones(1, float) + return ones(1, dtype=np.result_type(M, 0.0)) n = arange(1-M, M, 2) return 0.5 + 0.5*cos(pi*n/(M-1)) @@ -3117,9 +3117,9 @@ def hamming(M): """ if M < 1: - return array([]) + return array([], dtype=np.result_type(M, 0.0)) if M == 1: - return ones(1, float) + return ones(1, dtype=np.result_type(M, 0.0)) n = arange(1-M, M, 2) return 0.54 + 0.46*cos(pi*n/(M-1)) @@ -3396,7 +3396,7 @@ def kaiser(M, beta): """ if M == 1: - return np.array([1.]) + return np.ones(1, dtype=np.result_type(M, 0.0)) n = arange(0, M) alpha = (M-1)/2.0 return i0(beta * sqrt(1-((n-alpha)/alpha)**2.0))/i0(float(beta)) diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 5f27ea655..66110b479 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1528,7 +1528,7 @@ class TestVectorize: ([('x',)], [('y',), ()])) assert_equal(nfb._parse_gufunc_signature('(),(a,b,c),(d)->(d,e)'), ([(), ('a', 'b', 'c'), ('d',)], [('d', 'e')])) - + # Tests to check if whitespaces are ignored assert_equal(nfb._parse_gufunc_signature('(x )->()'), ([('x',)], [()])) assert_equal(nfb._parse_gufunc_signature('( x , y )->( )'), @@ -1853,35 +1853,116 @@ class TestUnwrap: assert sm_discont.dtype == wrap_uneven.dtype +@pytest.mark.parametrize( + "dtype", "O" + np.typecodes["AllInteger"] + np.typecodes["Float"] +) +@pytest.mark.parametrize("M", [0, 1, 10]) class TestFilterwindows: - def test_hanning(self): + def test_hanning(self, dtype: str, M: int) -> None: + scalar = np.array(M, dtype=dtype)[()] + + w = hanning(scalar) + if dtype == "O": + ref_dtype = np.float64 + else: + ref_dtype = np.result_type(scalar.dtype, np.float64) + assert w.dtype == ref_dtype + # check symmetry - w = hanning(10) assert_equal(w, flipud(w)) + # check known value - assert_almost_equal(np.sum(w, axis=0), 4.500, 4) + if scalar < 1: + assert_array_equal(w, np.array([])) + elif scalar == 1: + assert_array_equal(w, np.ones(1)) + else: + assert_almost_equal(np.sum(w, axis=0), 4.500, 4) + + def test_hamming(self, dtype: str, M: int) -> None: + scalar = np.array(M, dtype=dtype)[()] + + w = hamming(scalar) + if dtype == "O": + ref_dtype = np.float64 + else: + ref_dtype = np.result_type(scalar.dtype, np.float64) + assert w.dtype == ref_dtype + + # check symmetry + assert_equal(w, flipud(w)) + + # check known value + if scalar < 1: + assert_array_equal(w, np.array([])) + elif scalar == 1: + assert_array_equal(w, np.ones(1)) + else: + assert_almost_equal(np.sum(w, axis=0), 4.9400, 4) + + def test_bartlett(self, dtype: str, M: int) -> None: + scalar = np.array(M, dtype=dtype)[()] + + w = bartlett(scalar) + if dtype == "O": + ref_dtype = np.float64 + else: + ref_dtype = np.result_type(scalar.dtype, np.float64) + assert w.dtype == ref_dtype - def test_hamming(self): # check symmetry - w = hamming(10) assert_equal(w, flipud(w)) + # check known value - assert_almost_equal(np.sum(w, axis=0), 4.9400, 4) + if scalar < 1: + assert_array_equal(w, np.array([])) + elif scalar == 1: + assert_array_equal(w, np.ones(1)) + else: + assert_almost_equal(np.sum(w, axis=0), 4.4444, 4) + + def test_blackman(self, dtype: str, M: int) -> None: + scalar = np.array(M, dtype=dtype)[()] + + w = blackman(scalar) + if dtype == "O": + ref_dtype = np.float64 + else: + ref_dtype = np.result_type(scalar.dtype, np.float64) + assert w.dtype == ref_dtype - def test_bartlett(self): # check symmetry - w = bartlett(10) assert_equal(w, flipud(w)) + # check known value - assert_almost_equal(np.sum(w, axis=0), 4.4444, 4) + if scalar < 1: + assert_array_equal(w, np.array([])) + elif scalar == 1: + assert_array_equal(w, np.ones(1)) + else: + assert_almost_equal(np.sum(w, axis=0), 3.7800, 4) + + def test_kaiser(self, dtype: str, M: int) -> None: + scalar = np.array(M, dtype=dtype)[()] + + w = kaiser(scalar, 0) + if dtype == "O": + ref_dtype = np.float64 + else: + ref_dtype = np.result_type(scalar.dtype, np.float64) + assert w.dtype == ref_dtype - def test_blackman(self): # check symmetry - w = blackman(10) assert_equal(w, flipud(w)) + # check known value - assert_almost_equal(np.sum(w, axis=0), 3.7800, 4) + if scalar < 1: + assert_array_equal(w, np.array([])) + elif scalar == 1: + assert_array_equal(w, np.ones(1)) + else: + assert_almost_equal(np.sum(w, axis=0), 10, 15) class TestTrapz: |