summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-09-08 10:30:55 -0500
committerGitHub <noreply@github.com>2018-09-08 10:30:55 -0500
commit1b400ea346db432a20994b9848a88a5f42567263 (patch)
tree2c81a23f65bcc735c74ead860cb870516969f01c
parentc16396374ca633da49621666582275fca8ccfe27 (diff)
parentb3125bae72efe2bd7d0c40785e4511584ee6d4dc (diff)
downloadnumpy-1b400ea346db432a20994b9848a88a5f42567263.tar.gz
Merge pull request #11904 from QuLogic/core-parametrize
Use pytest for some already-parametrized core tests
-rw-r--r--numpy/core/tests/test_multiarray.py73
-rw-r--r--numpy/core/tests/test_print.py71
2 files changed, 49 insertions, 95 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 1c59abaa7..209b3c533 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4133,15 +4133,12 @@ class TestPutmask(object):
def test_mask_size(self):
assert_raises(ValueError, np.putmask, np.array([1, 2, 3]), [True], 5)
- def tst_byteorder(self, dtype):
+ @pytest.mark.parametrize('dtype', ('>i4', '<i4'))
+ def test_byteorder(self, dtype):
x = np.array([1, 2, 3], dtype)
np.putmask(x, [True, False, True], -1)
assert_array_equal(x, [-1, 2, -1])
- def test_ip_byteorder(self):
- for dtype in ('>i4', '<i4'):
- self.tst_byteorder(dtype)
-
def test_record_array(self):
# Note mixed byteorder.
rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
@@ -4191,14 +4188,11 @@ class TestTake(object):
assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0])
assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1])
- def tst_byteorder(self, dtype):
+ @pytest.mark.parametrize('dtype', ('>i4', '<i4'))
+ def test_byteorder(self, dtype):
x = np.array([1, 2, 3], dtype)
assert_array_equal(x.take([0, 2, 1]), [1, 3, 2])
- def test_ip_byteorder(self):
- for dtype in ('>i4', '<i4'):
- self.tst_byteorder(dtype)
-
def test_record_array(self):
# Note mixed byteorder.
rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
@@ -4574,19 +4568,16 @@ class TestIO(object):
class TestFromBuffer(object):
- def tst_basic(self, buffer, expected, kwargs):
- assert_array_equal(np.frombuffer(buffer,**kwargs), expected)
-
- def test_ip_basic(self):
- for byteorder in ['<', '>']:
- for dtype in [float, int, complex]:
- dt = np.dtype(dtype).newbyteorder(byteorder)
- x = (np.random.random((4, 7))*5).astype(dt)
- buf = x.tobytes()
- self.tst_basic(buf, x.flat, {'dtype':dt})
+ @pytest.mark.parametrize('byteorder', ['<', '>'])
+ @pytest.mark.parametrize('dtype', [float, int, complex])
+ def test_basic(self, byteorder, dtype):
+ dt = np.dtype(dtype).newbyteorder(byteorder)
+ x = (np.random.random((4, 7)) * 5).astype(dt)
+ buf = x.tobytes()
+ assert_array_equal(np.frombuffer(buf, dtype=dt), x.flat)
def test_empty(self):
- self.tst_basic(b'', np.array([]), {})
+ assert_array_equal(np.frombuffer(b''), np.array([]))
class TestFlat(object):
@@ -5940,9 +5931,10 @@ class TestRepeat(object):
NEIGH_MODE = {'zero': 0, 'one': 1, 'constant': 2, 'circular': 3, 'mirror': 4}
+@pytest.mark.parametrize('dt', [float, Decimal], ids=['float', 'object'])
class TestNeighborhoodIter(object):
# Simple, 2d tests
- def _test_simple2d(self, dt):
+ def test_simple2d(self, dt):
# Test zero and one padding for simple data type
x = np.array([[0, 1], [2, 3]], dtype=dt)
r = [np.array([[0, 0, 0], [0, 0, 1]], dtype=dt),
@@ -5969,13 +5961,7 @@ class TestNeighborhoodIter(object):
x, [-1, 0, -1, 1], 4, NEIGH_MODE['constant'])
assert_array_equal(l, r)
- def test_simple2d(self):
- self._test_simple2d(float)
-
- def test_simple2d_object(self):
- self._test_simple2d(Decimal)
-
- def _test_mirror2d(self, dt):
+ def test_mirror2d(self, dt):
x = np.array([[0, 1], [2, 3]], dtype=dt)
r = [np.array([[0, 0, 1], [0, 0, 1]], dtype=dt),
np.array([[0, 1, 1], [0, 1, 1]], dtype=dt),
@@ -5985,14 +5971,8 @@ class TestNeighborhoodIter(object):
x, [-1, 0, -1, 1], x[0], NEIGH_MODE['mirror'])
assert_array_equal(l, r)
- def test_mirror2d(self):
- self._test_mirror2d(float)
-
- def test_mirror2d_object(self):
- self._test_mirror2d(Decimal)
-
# Simple, 1d tests
- def _test_simple(self, dt):
+ def test_simple(self, dt):
# Test padding with constant values
x = np.linspace(1, 5, 5).astype(dt)
r = [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 0]]
@@ -6010,14 +5990,8 @@ class TestNeighborhoodIter(object):
x, [-1, 1], x[4], NEIGH_MODE['constant'])
assert_array_equal(l, r)
- def test_simple_float(self):
- self._test_simple(float)
-
- def test_simple_object(self):
- self._test_simple(Decimal)
-
# Test mirror modes
- def _test_mirror(self, dt):
+ def test_mirror(self, dt):
x = np.linspace(1, 5, 5).astype(dt)
r = np.array([[2, 1, 1, 2, 3], [1, 1, 2, 3, 4], [1, 2, 3, 4, 5],
[2, 3, 4, 5, 5], [3, 4, 5, 5, 4]], dtype=dt)
@@ -6026,14 +6000,8 @@ class TestNeighborhoodIter(object):
assert_([i.dtype == dt for i in l])
assert_array_equal(l, r)
- def test_mirror(self):
- self._test_mirror(float)
-
- def test_mirror_object(self):
- self._test_mirror(Decimal)
-
# Circular mode
- def _test_circular(self, dt):
+ def test_circular(self, dt):
x = np.linspace(1, 5, 5).astype(dt)
r = np.array([[4, 5, 1, 2, 3], [5, 1, 2, 3, 4], [1, 2, 3, 4, 5],
[2, 3, 4, 5, 1], [3, 4, 5, 1, 2]], dtype=dt)
@@ -6041,11 +6009,6 @@ class TestNeighborhoodIter(object):
x, [-2, 2], x[0], NEIGH_MODE['circular'])
assert_array_equal(l, r)
- def test_circular(self):
- self._test_circular(float)
-
- def test_circular_object(self):
- self._test_circular(Decimal)
# Test stacking neighborhood iterators
class TestStackedNeighborhoodIter(object):
diff --git a/numpy/core/tests/test_print.py b/numpy/core/tests/test_print.py
index 433208748..77679424c 100644
--- a/numpy/core/tests/test_print.py
+++ b/numpy/core/tests/test_print.py
@@ -2,6 +2,8 @@ from __future__ import division, absolute_import, print_function
import sys
+import pytest
+
import numpy as np
from numpy.testing import assert_, assert_equal, SkipTest
from numpy.core.tests._locales import CommaDecimalPointLocale
@@ -15,7 +17,15 @@ else:
_REF = {np.inf: 'inf', -np.inf: '-inf', np.nan: 'nan'}
-def check_float_type(tp):
+@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
+def test_float_types(tp):
+ """ Check formatting.
+
+ This is only for the str function, and only for simple types.
+ The precision of np.float32 and np.longdouble aren't the same as the
+ python float precision.
+
+ """
for x in [0, 1, -1, 1e20]:
assert_equal(str(tp(x)), str(float(x)),
err_msg='Failed str formatting for type %s' % tp)
@@ -28,34 +38,30 @@ def check_float_type(tp):
assert_equal(str(tp(1e16)), ref,
err_msg='Failed str formatting for type %s' % tp)
-def test_float_types():
- """ Check formatting.
+
+@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
+def test_nan_inf_float(tp):
+ """ Check formatting of nan & inf.
This is only for the str function, and only for simple types.
The precision of np.float32 and np.longdouble aren't the same as the
python float precision.
"""
- for t in [np.float32, np.double, np.longdouble]:
- check_float_type(t)
-
-def check_nan_inf_float(tp):
for x in [np.inf, -np.inf, np.nan]:
assert_equal(str(tp(x)), _REF[x],
err_msg='Failed str formatting for type %s' % tp)
-def test_nan_inf_float():
- """ Check formatting of nan & inf.
+
+@pytest.mark.parametrize('tp', [np.complex64, np.cdouble, np.clongdouble])
+def test_complex_types(tp):
+ """Check formatting of complex types.
This is only for the str function, and only for simple types.
The precision of np.float32 and np.longdouble aren't the same as the
python float precision.
"""
- for t in [np.float32, np.double, np.longdouble]:
- check_nan_inf_float(t)
-
-def check_complex_type(tp):
for x in [0, 1, -1, 1e20]:
assert_equal(str(tp(x)), str(complex(x)),
err_msg='Failed str formatting for type %s' % tp)
@@ -72,18 +78,9 @@ def check_complex_type(tp):
assert_equal(str(tp(1e16)), ref,
err_msg='Failed str formatting for type %s' % tp)
-def test_complex_types():
- """Check formatting of complex types.
-
- This is only for the str function, and only for simple types.
- The precision of np.float32 and np.longdouble aren't the same as the
- python float precision.
-
- """
- for t in [np.complex64, np.cdouble, np.clongdouble]:
- check_complex_type(t)
-def test_complex_inf_nan():
+@pytest.mark.parametrize('dtype', [np.complex64, np.cdouble, np.clongdouble])
+def test_complex_inf_nan(dtype):
"""Check inf/nan formatting of complex types."""
TESTS = {
complex(np.inf, 0): "(inf+0j)",
@@ -103,12 +100,9 @@ def test_complex_inf_nan():
complex(-np.nan, 1): "(nan+1j)",
complex(1, -np.nan): "(1+nanj)",
}
- for tp in [np.complex64, np.cdouble, np.clongdouble]:
- for c, s in TESTS.items():
- _check_complex_inf_nan(c, s, tp)
+ for c, s in TESTS.items():
+ assert_equal(str(dtype(c)), s)
-def _check_complex_inf_nan(c, s, dtype):
- assert_equal(str(dtype(c)), s)
# print tests
def _test_redirected_print(x, tp, ref=None):
@@ -129,7 +123,10 @@ def _test_redirected_print(x, tp, ref=None):
assert_equal(file.getvalue(), file_tp.getvalue(),
err_msg='print failed for type%s' % tp)
-def check_float_type_print(tp):
+
+@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
+def test_float_type_print(tp):
+ """Check formatting when using print """
for x in [0, 1, -1, 1e20]:
_test_redirected_print(float(x), tp)
@@ -142,7 +139,10 @@ def check_float_type_print(tp):
ref = '1e+16'
_test_redirected_print(float(1e16), tp, ref)
-def check_complex_type_print(tp):
+
+@pytest.mark.parametrize('tp', [np.complex64, np.cdouble, np.clongdouble])
+def test_complex_type_print(tp):
+ """Check formatting when using print """
# We do not create complex with inf/nan directly because the feature is
# missing in python < 2.6
for x in [0, 1, -1, 1e20]:
@@ -158,15 +158,6 @@ def check_complex_type_print(tp):
_test_redirected_print(complex(-np.inf, 1), tp, '(-inf+1j)')
_test_redirected_print(complex(-np.nan, 1), tp, '(nan+1j)')
-def test_float_type_print():
- """Check formatting when using print """
- for t in [np.float32, np.double, np.longdouble]:
- check_float_type_print(t)
-
-def test_complex_type_print():
- """Check formatting when using print """
- for t in [np.complex64, np.cdouble, np.clongdouble]:
- check_complex_type_print(t)
def test_scalar_format():
"""Test the str.format method with NumPy scalar types"""