summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDWesl <22566757+DWesl@users.noreply.github.com>2021-01-01 13:07:27 -0500
committerDWesl <22566757+DWesl@users.noreply.github.com>2021-01-01 13:07:27 -0500
commit4799b904b759c035041d30e4cf2fe7340aac3955 (patch)
treedf67a413b6b27ab8a3e6bd2d6a9071633983160f
parent31647f1b3e56c2f2a471cd2c3a583311534173f8 (diff)
downloadnumpy-4799b904b759c035041d30e4cf2fe7340aac3955.tar.gz
TST: Turn some tests with loos into parametrized tests.
I wanted to mark only some parts of the loops as xfail for another PR. That part of the PR probably won't make it into numpy, but I think parametrized tests give better information on failure than tests with loops do, so I'm submitting these here.
-rw-r--r--numpy/core/tests/test_multiarray.py38
-rw-r--r--numpy/core/tests/test_scalarmath.py66
2 files changed, 61 insertions, 43 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 048b1688f..624e1aa2d 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -8576,23 +8576,29 @@ def test_equal_override():
assert_equal(array != my_always_equal, 'ne')
-def test_npymath_complex():
+@pytest.mark.parametrize(
+ ["fun", "npfun", "x", "y", "test_dtype"],
+ [
+ pytest.param(
+ fun, npfun, x, y, test_dtype
+ )
+ for (fun, npfun), x, y, test_dtype in itertools.product(
+ [
+ (_multiarray_tests.npy_cabs, np.absolute),
+ (_multiarray_tests.npy_carg, np.angle),
+ ],
+ [1, np.inf, -np.inf, np.nan],
+ [1, np.inf, -np.inf, np.nan],
+ [np.complex64, np.complex128, np.clongdouble],
+ )
+ ],
+)
+def test_npymath_complex(fun, npfun, x, y, test_dtype):
# Smoketest npymath functions
- from numpy.core._multiarray_tests import (
- npy_cabs, npy_carg)
-
- funcs = {npy_cabs: np.absolute,
- npy_carg: np.angle}
- vals = (1, np.inf, -np.inf, np.nan)
- types = (np.complex64, np.complex128, np.clongdouble)
-
- for fun, npfun in funcs.items():
- for x, y in itertools.product(vals, vals):
- for t in types:
- z = t(complex(x, y))
- got = fun(z)
- expected = npfun(z)
- assert_allclose(got, expected)
+ z = test_dtype(complex(x, y))
+ got = fun(z)
+ expected = npfun(z)
+ assert_allclose(got, expected)
def test_npymath_real():
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index d8529418e..5b07e36fa 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -653,33 +653,45 @@ class TestSubtract:
class TestAbs:
- def _test_abs_func(self, absfunc):
- for tp in floating_types + complex_floating_types:
- x = tp(-1.5)
- assert_equal(absfunc(x), 1.5)
- x = tp(0.0)
- res = absfunc(x)
- # assert_equal() checks zero signedness
- assert_equal(res, 0.0)
- x = tp(-0.0)
- res = absfunc(x)
- assert_equal(res, 0.0)
-
- x = tp(np.finfo(tp).max)
- assert_equal(absfunc(x), x.real)
-
- x = tp(np.finfo(tp).tiny)
- assert_equal(absfunc(x), x.real)
-
- x = tp(np.finfo(tp).min)
- assert_equal(absfunc(x), -x.real)
-
- def test_builtin_abs(self):
- self._test_abs_func(abs)
-
- def test_numpy_abs(self):
- self._test_abs_func(np.abs)
-
+ def _test_abs_func(self, absfunc, test_dtype):
+ x = test_dtype(-1.5)
+ assert_equal(absfunc(x), 1.5)
+ x = test_dtype(0.0)
+ res = absfunc(x)
+ # assert_equal() checks zero signedness
+ assert_equal(res, 0.0)
+ x = test_dtype(-0.0)
+ res = absfunc(x)
+ assert_equal(res, 0.0)
+
+ x = test_dtype(np.finfo(test_dtype).max)
+ assert_equal(absfunc(x), x.real)
+
+ x = test_dtype(np.finfo(test_dtype).tiny)
+ assert_equal(absfunc(x), x.real)
+
+ x = test_dtype(np.finfo(test_dtype).min)
+ assert_equal(absfunc(x), -x.real)
+
+ @pytest.mark.parametrize(
+ "dtype",
+ [
+ pytest.param(dtype)
+ for dtype in floating_types + complex_floating_types
+ ],
+ )
+ def test_builtin_abs(self, dtype):
+ self._test_abs_func(abs, dtype)
+
+ @pytest.mark.parametrize(
+ "dtype",
+ [
+ pytest.param(dtype)
+ for dtype in floating_types + complex_floating_types
+ ],
+ )
+ def test_numpy_abs(self, dtype):
+ self._test_abs_func(np.abs, dtype)
class TestBitShifts: