summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2012-11-13 17:48:55 +0100
committerOndřej Čertík <ondrej.certik@gmail.com>2012-12-06 14:22:57 -0800
commit7f9d7bcf616371f3318513ca5500f19124f573e1 (patch)
treeaf4824181e944ffee89d0ba3ffe55049b180cbea /numpy/random/tests
parent3e5b9b2366607bc3d85b60c34b370327e491e1ef (diff)
downloadnumpy-7f9d7bcf616371f3318513ca5500f19124f573e1.tar.gz
TST: Add tests for new feature and fix in random.choice
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_random.py28
1 files changed, 27 insertions, 1 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index ee40cce69..e8875e578 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -138,7 +138,7 @@ class TestRandomDist(TestCase):
np.random.seed(self.seed)
actual = np.random.choice(4, 3, replace=False,
p=[0.1, 0.3, 0.5, 0.1])
- desired = np.array([2, 1, 3])
+ desired = np.array([2, 3, 1])
np.testing.assert_array_equal(actual, desired)
def test_choice_noninteger(self):
@@ -161,6 +161,32 @@ class TestRandomDist(TestCase):
assert_raises(ValueError, sample, [1,2,3], 2, replace=False,
p=[1,0,0])
+ def test_choice_return_shape(self):
+ p = [0.1,0.9]
+ # Check scalar
+ assert_(np.isscalar(np.random.choice(2, replace=True)))
+ assert_(np.isscalar(np.random.choice(2, replace=False)))
+ assert_(np.isscalar(np.random.choice(2, replace=True, p=p)))
+ assert_(np.isscalar(np.random.choice(2, replace=False, p=p)))
+ assert_(np.isscalar(np.random.choice([1,2], replace=True)))
+
+ # Check 0-d array
+ s = tuple()
+ assert_(not np.isscalar(np.random.choice(2, s, replace=True)))
+ assert_(not np.isscalar(np.random.choice(2, s, replace=False)))
+ assert_(not np.isscalar(np.random.choice(2, s, replace=True, p=p)))
+ assert_(not np.isscalar(np.random.choice(2, s, replace=False, p=p)))
+ assert_(not np.isscalar(np.random.choice([1,2], s, replace=True)))
+
+ # Check multi dimensional array
+ s = (2,3)
+ p = [0.1, 0.1, 0.1, 0.1, 0.4, 0.2]
+ assert_(np.random.choice(6, s, replace=True).shape, s)
+ assert_(np.random.choice(6, s, replace=False).shape, s)
+ assert_(np.random.choice(6, s, replace=True, p=p).shape, s)
+ assert_(np.random.choice(6, s, replace=False, p=p).shape, s)
+ assert_(np.random.choice(np.arange(6), s, replace=True).shape, s)
+
def test_bytes(self):
np.random.seed(self.seed)
actual = np.random.bytes(10)