summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests/test_linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r--numpy/linalg/tests/test_linalg.py57
1 files changed, 31 insertions, 26 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index b1dbd4c22..17ee40042 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -589,11 +589,12 @@ class TestEigvals(EigvalsCases):
class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
def do(self, a, b, tags):
- evalues, evectors = linalg.eig(a)
- assert_allclose(dot_generalized(a, evectors),
- np.asarray(evectors) * np.asarray(evalues)[..., None, :],
- rtol=get_rtol(evalues.dtype))
- assert_(consistent_subclass(evectors, a))
+ res = linalg.eig(a)
+ eigenvalues, eigenvectors = res.eigenvalues, res.eigenvectors
+ assert_allclose(dot_generalized(a, eigenvectors),
+ np.asarray(eigenvectors) * np.asarray(eigenvalues)[..., None, :],
+ rtol=get_rtol(eigenvalues.dtype))
+ assert_(consistent_subclass(eigenvectors, a))
class TestEig(EigCases):
@@ -638,10 +639,11 @@ class SVDBaseTests:
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
def test_types(self, dtype):
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- u, s, vh = linalg.svd(x)
- assert_equal(u.dtype, dtype)
- assert_equal(s.dtype, get_real_dtype(dtype))
- assert_equal(vh.dtype, dtype)
+ res = linalg.svd(x)
+ U, S, Vh = res.U, res.S, res.Vh
+ assert_equal(U.dtype, dtype)
+ assert_equal(S.dtype, get_real_dtype(dtype))
+ assert_equal(Vh.dtype, dtype)
s = linalg.svd(x, compute_uv=False, hermitian=self.hermitian)
assert_equal(s.dtype, get_real_dtype(dtype))
@@ -844,7 +846,8 @@ class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
def do(self, a, b, tags):
d = linalg.det(a)
- (s, ld) = linalg.slogdet(a)
+ res = linalg.slogdet(a)
+ s, ld = res.sign, res.logabsdet
if asarray(a).dtype.type in (single, double):
ad = asarray(a).astype(double)
else:
@@ -1144,7 +1147,8 @@ class TestEighCases(HermitianTestCase, HermitianGeneralizedTestCase):
def do(self, a, b, tags):
# note that eigenvalue arrays returned by eig must be sorted since
# their order isn't guaranteed.
- ev, evc = linalg.eigh(a)
+ res = linalg.eigh(a)
+ ev, evc = res.eigenvalues, res.eigenvectors
evalues, evectors = linalg.eig(a)
evalues.sort(axis=-1)
assert_almost_equal(ev, evalues)
@@ -1632,16 +1636,17 @@ class TestQR:
k = min(m, n)
# mode == 'complete'
- q, r = linalg.qr(a, mode='complete')
- assert_(q.dtype == a_dtype)
- assert_(r.dtype == a_dtype)
- assert_(isinstance(q, a_type))
- assert_(isinstance(r, a_type))
- assert_(q.shape == (m, m))
- assert_(r.shape == (m, n))
- assert_almost_equal(dot(q, r), a)
- assert_almost_equal(dot(q.T.conj(), q), np.eye(m))
- assert_almost_equal(np.triu(r), r)
+ res = linalg.qr(a, mode='complete')
+ Q, R = res.Q, res.R
+ assert_(Q.dtype == a_dtype)
+ assert_(R.dtype == a_dtype)
+ assert_(isinstance(Q, a_type))
+ assert_(isinstance(R, a_type))
+ assert_(Q.shape == (m, m))
+ assert_(R.shape == (m, n))
+ assert_almost_equal(dot(Q, R), a)
+ assert_almost_equal(dot(Q.T.conj(), Q), np.eye(m))
+ assert_almost_equal(np.triu(R), R)
# mode == 'reduced'
q1, r1 = linalg.qr(a, mode='reduced')
@@ -1736,7 +1741,7 @@ class TestQR:
assert_(r.shape[-2:] == (m, n))
assert_almost_equal(matmul(q, r), a)
I_mat = np.identity(q.shape[-1])
- stack_I_mat = np.broadcast_to(I_mat,
+ stack_I_mat = np.broadcast_to(I_mat,
q.shape[:-2] + (q.shape[-1],)*2)
assert_almost_equal(matmul(swapaxes(q, -1, -2).conj(), q), stack_I_mat)
assert_almost_equal(np.triu(r[..., :, :]), r)
@@ -1751,9 +1756,9 @@ class TestQR:
assert_(r1.shape[-2:] == (k, n))
assert_almost_equal(matmul(q1, r1), a)
I_mat = np.identity(q1.shape[-1])
- stack_I_mat = np.broadcast_to(I_mat,
+ stack_I_mat = np.broadcast_to(I_mat,
q1.shape[:-2] + (q1.shape[-1],)*2)
- assert_almost_equal(matmul(swapaxes(q1, -1, -2).conj(), q1),
+ assert_almost_equal(matmul(swapaxes(q1, -1, -2).conj(), q1),
stack_I_mat)
assert_almost_equal(np.triu(r1[..., :, :]), r1)
@@ -1764,12 +1769,12 @@ class TestQR:
assert_almost_equal(r2, r1)
@pytest.mark.parametrize("size", [
- (3, 4), (4, 3), (4, 4),
+ (3, 4), (4, 3), (4, 4),
(3, 0), (0, 3)])
@pytest.mark.parametrize("outer_size", [
(2, 2), (2,), (2, 3, 4)])
@pytest.mark.parametrize("dt", [
- np.single, np.double,
+ np.single, np.double,
np.csingle, np.cdouble])
def test_stacked_inputs(self, outer_size, size, dt):