diff options
| author | Sebastian Berg <sebastian@sipsolutions.net> | 2012-11-13 17:48:55 +0100 |
|---|---|---|
| committer | Ondřej Čertík <ondrej.certik@gmail.com> | 2012-12-06 14:22:57 -0800 |
| commit | 7f9d7bcf616371f3318513ca5500f19124f573e1 (patch) | |
| tree | af4824181e944ffee89d0ba3ffe55049b180cbea /numpy/random/tests | |
| parent | 3e5b9b2366607bc3d85b60c34b370327e491e1ef (diff) | |
| download | numpy-7f9d7bcf616371f3318513ca5500f19124f573e1.tar.gz | |
TST: Add tests for new feature and fix in random.choice
Diffstat (limited to 'numpy/random/tests')
| -rw-r--r-- | numpy/random/tests/test_random.py | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index ee40cce69..e8875e578 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -138,7 +138,7 @@ class TestRandomDist(TestCase): np.random.seed(self.seed) actual = np.random.choice(4, 3, replace=False, p=[0.1, 0.3, 0.5, 0.1]) - desired = np.array([2, 1, 3]) + desired = np.array([2, 3, 1]) np.testing.assert_array_equal(actual, desired) def test_choice_noninteger(self): @@ -161,6 +161,32 @@ class TestRandomDist(TestCase): assert_raises(ValueError, sample, [1,2,3], 2, replace=False, p=[1,0,0]) + def test_choice_return_shape(self): + p = [0.1,0.9] + # Check scalar + assert_(np.isscalar(np.random.choice(2, replace=True))) + assert_(np.isscalar(np.random.choice(2, replace=False))) + assert_(np.isscalar(np.random.choice(2, replace=True, p=p))) + assert_(np.isscalar(np.random.choice(2, replace=False, p=p))) + assert_(np.isscalar(np.random.choice([1,2], replace=True))) + + # Check 0-d array + s = tuple() + assert_(not np.isscalar(np.random.choice(2, s, replace=True))) + assert_(not np.isscalar(np.random.choice(2, s, replace=False))) + assert_(not np.isscalar(np.random.choice(2, s, replace=True, p=p))) + assert_(not np.isscalar(np.random.choice(2, s, replace=False, p=p))) + assert_(not np.isscalar(np.random.choice([1,2], s, replace=True))) + + # Check multi dimensional array + s = (2,3) + p = [0.1, 0.1, 0.1, 0.1, 0.4, 0.2] + assert_(np.random.choice(6, s, replace=True).shape, s) + assert_(np.random.choice(6, s, replace=False).shape, s) + assert_(np.random.choice(6, s, replace=True, p=p).shape, s) + assert_(np.random.choice(6, s, replace=False, p=p).shape, s) + assert_(np.random.choice(np.arange(6), s, replace=True).shape, s) + def test_bytes(self): np.random.seed(self.seed) actual = np.random.bytes(10) |
