summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-10-04 09:30:36 -0600
committerGitHub <noreply@github.com>2021-10-04 09:30:36 -0600
commit7dabf22f1ea47633f4c7fb2848ae0612742687c8 (patch)
tree5385fcba422feaff4c97009269faa8206c246b02
parente7842a29b0eed67f131d49143525145a9639c935 (diff)
parent476903f189c01cc9b0480b10d80bfca1070ee442 (diff)
downloadnumpy-7dabf22f1ea47633f4c7fb2848ae0612742687c8.tar.gz
Merge pull request #20008 from BvB93/window-func
BUG: Fix the `lib.function_base` window functions ignoring extended precision float dtypes
-rw-r--r--numpy/lib/function_base.py18
-rw-r--r--numpy/lib/tests/test_function_base.py107
2 files changed, 103 insertions, 22 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 80eaf8acf..3ca566f73 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -2804,9 +2804,9 @@ def blackman(M):
"""
if M < 1:
- return array([])
+ return array([], dtype=np.result_type(M, 0.0))
if M == 1:
- return ones(1, float)
+ return ones(1, dtype=np.result_type(M, 0.0))
n = arange(1-M, M, 2)
return 0.42 + 0.5*cos(pi*n/(M-1)) + 0.08*cos(2.0*pi*n/(M-1))
@@ -2913,9 +2913,9 @@ def bartlett(M):
"""
if M < 1:
- return array([])
+ return array([], dtype=np.result_type(M, 0.0))
if M == 1:
- return ones(1, float)
+ return ones(1, dtype=np.result_type(M, 0.0))
n = arange(1-M, M, 2)
return where(less_equal(n, 0), 1 + n/(M-1), 1 - n/(M-1))
@@ -3017,9 +3017,9 @@ def hanning(M):
"""
if M < 1:
- return array([])
+ return array([], dtype=np.result_type(M, 0.0))
if M == 1:
- return ones(1, float)
+ return ones(1, dtype=np.result_type(M, 0.0))
n = arange(1-M, M, 2)
return 0.5 + 0.5*cos(pi*n/(M-1))
@@ -3117,9 +3117,9 @@ def hamming(M):
"""
if M < 1:
- return array([])
+ return array([], dtype=np.result_type(M, 0.0))
if M == 1:
- return ones(1, float)
+ return ones(1, dtype=np.result_type(M, 0.0))
n = arange(1-M, M, 2)
return 0.54 + 0.46*cos(pi*n/(M-1))
@@ -3396,7 +3396,7 @@ def kaiser(M, beta):
"""
if M == 1:
- return np.array([1.])
+ return np.ones(1, dtype=np.result_type(M, 0.0))
n = arange(0, M)
alpha = (M-1)/2.0
return i0(beta * sqrt(1-((n-alpha)/alpha)**2.0))/i0(float(beta))
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 5f27ea655..66110b479 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1528,7 +1528,7 @@ class TestVectorize:
([('x',)], [('y',), ()]))
assert_equal(nfb._parse_gufunc_signature('(),(a,b,c),(d)->(d,e)'),
([(), ('a', 'b', 'c'), ('d',)], [('d', 'e')]))
-
+
# Tests to check if whitespaces are ignored
assert_equal(nfb._parse_gufunc_signature('(x )->()'), ([('x',)], [()]))
assert_equal(nfb._parse_gufunc_signature('( x , y )->( )'),
@@ -1853,35 +1853,116 @@ class TestUnwrap:
assert sm_discont.dtype == wrap_uneven.dtype
+@pytest.mark.parametrize(
+ "dtype", "O" + np.typecodes["AllInteger"] + np.typecodes["Float"]
+)
+@pytest.mark.parametrize("M", [0, 1, 10])
class TestFilterwindows:
- def test_hanning(self):
+ def test_hanning(self, dtype: str, M: int) -> None:
+ scalar = np.array(M, dtype=dtype)[()]
+
+ w = hanning(scalar)
+ if dtype == "O":
+ ref_dtype = np.float64
+ else:
+ ref_dtype = np.result_type(scalar.dtype, np.float64)
+ assert w.dtype == ref_dtype
+
# check symmetry
- w = hanning(10)
assert_equal(w, flipud(w))
+
# check known value
- assert_almost_equal(np.sum(w, axis=0), 4.500, 4)
+ if scalar < 1:
+ assert_array_equal(w, np.array([]))
+ elif scalar == 1:
+ assert_array_equal(w, np.ones(1))
+ else:
+ assert_almost_equal(np.sum(w, axis=0), 4.500, 4)
+
+ def test_hamming(self, dtype: str, M: int) -> None:
+ scalar = np.array(M, dtype=dtype)[()]
+
+ w = hamming(scalar)
+ if dtype == "O":
+ ref_dtype = np.float64
+ else:
+ ref_dtype = np.result_type(scalar.dtype, np.float64)
+ assert w.dtype == ref_dtype
+
+ # check symmetry
+ assert_equal(w, flipud(w))
+
+ # check known value
+ if scalar < 1:
+ assert_array_equal(w, np.array([]))
+ elif scalar == 1:
+ assert_array_equal(w, np.ones(1))
+ else:
+ assert_almost_equal(np.sum(w, axis=0), 4.9400, 4)
+
+ def test_bartlett(self, dtype: str, M: int) -> None:
+ scalar = np.array(M, dtype=dtype)[()]
+
+ w = bartlett(scalar)
+ if dtype == "O":
+ ref_dtype = np.float64
+ else:
+ ref_dtype = np.result_type(scalar.dtype, np.float64)
+ assert w.dtype == ref_dtype
- def test_hamming(self):
# check symmetry
- w = hamming(10)
assert_equal(w, flipud(w))
+
# check known value
- assert_almost_equal(np.sum(w, axis=0), 4.9400, 4)
+ if scalar < 1:
+ assert_array_equal(w, np.array([]))
+ elif scalar == 1:
+ assert_array_equal(w, np.ones(1))
+ else:
+ assert_almost_equal(np.sum(w, axis=0), 4.4444, 4)
+
+ def test_blackman(self, dtype: str, M: int) -> None:
+ scalar = np.array(M, dtype=dtype)[()]
+
+ w = blackman(scalar)
+ if dtype == "O":
+ ref_dtype = np.float64
+ else:
+ ref_dtype = np.result_type(scalar.dtype, np.float64)
+ assert w.dtype == ref_dtype
- def test_bartlett(self):
# check symmetry
- w = bartlett(10)
assert_equal(w, flipud(w))
+
# check known value
- assert_almost_equal(np.sum(w, axis=0), 4.4444, 4)
+ if scalar < 1:
+ assert_array_equal(w, np.array([]))
+ elif scalar == 1:
+ assert_array_equal(w, np.ones(1))
+ else:
+ assert_almost_equal(np.sum(w, axis=0), 3.7800, 4)
+
+ def test_kaiser(self, dtype: str, M: int) -> None:
+ scalar = np.array(M, dtype=dtype)[()]
+
+ w = kaiser(scalar, 0)
+ if dtype == "O":
+ ref_dtype = np.float64
+ else:
+ ref_dtype = np.result_type(scalar.dtype, np.float64)
+ assert w.dtype == ref_dtype
- def test_blackman(self):
# check symmetry
- w = blackman(10)
assert_equal(w, flipud(w))
+
# check known value
- assert_almost_equal(np.sum(w, axis=0), 3.7800, 4)
+ if scalar < 1:
+ assert_array_equal(w, np.array([]))
+ elif scalar == 1:
+ assert_array_equal(w, np.ones(1))
+ else:
+ assert_almost_equal(np.sum(w, axis=0), 10, 15)
class TestTrapz: