diff options
| author | Max Balandat <balandat@fb.com> | 2020-03-30 19:24:03 -0700 |
|---|---|---|
| committer | Max Balandat <balandat@fb.com> | 2020-04-04 18:27:31 -0700 |
| commit | 72457f01832d10c72a1839aafc178cf4f53449cb (patch) | |
| tree | 8872abe06a6f860e4f9099795ee9900cb66683b4 /numpy/random/tests | |
| parent | f0a74b2d4e50a1b4f9d9189c6b2e31b920913a8b (diff) | |
| download | numpy-72457f01832d10c72a1839aafc178cf4f53449cb.tar.gz | |
Bug: Fix eigh mnd cholesky methods of numpy.random.multivariate_normal
Fixes #15871
Diffstat (limited to 'numpy/random/tests')
| -rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 6f4407373..b10c1310e 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -1242,6 +1242,17 @@ class TestRandomDist: assert_raises(ValueError, random.multivariate_normal, mean, cov, check_valid='raise', method='eigh') + # check degenerate samples from singular covariance matrix + cov = [[1, 1], [1, 1]] + if method in ('svd', 'eigh'): + samples = random.multivariate_normal(mean, cov, size=(3, 2), + method=method) + assert_array_almost_equal(samples[..., 0], samples[..., 1], + decimal=6) + else: + assert_raises(LinAlgError, random.multivariate_normal, mean, cov, + method='cholesky') + cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32) with suppress_warnings() as sup: random.multivariate_normal(mean, cov, method=method) @@ -1259,6 +1270,19 @@ class TestRandomDist: assert_raises(ValueError, random.multivariate_normal, mu, np.eye(3)) + @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"]) + def test_multivariate_normal_basic_stats(self, method): + random = Generator(MT19937(self.seed)) + n_s = 1000 + mean = np.array([1, 2]) + cov = np.array([[2, 1], [1, 2]]) + s = random.multivariate_normal(mean, cov, size=(n_s,), method=method) + s_center = s - mean + cov_emp = (s_center.T @ s_center) / (n_s - 1) + # these are pretty loose and are only designed to detect major errors + assert np.all(np.abs(s_center.mean(-2)) < 0.1) + assert np.all(np.abs(cov_emp - cov) < 0.2) + def test_negative_binomial(self): random = Generator(MT19937(self.seed)) actual = random.negative_binomial(n=100, p=.12345, size=(3, 2)) |
