summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorMatteo Raso <mraso@uoguelph.ca>2022-08-25 23:09:46 -0400
committerMatteo Raso <mraso@uoguelph.ca>2022-08-25 23:09:46 -0400
commit1647f467a7fcb000ffeff0f2fd4b30f23a32d882 (patch)
tree2795bba990ded4affde724d72ffbc42c033e1e48 /numpy/random
parent50a74fb65fc752e77a2f9e9e2b7227629c2ba953 (diff)
downloadnumpy-1647f467a7fcb000ffeff0f2fd4b30f23a32d882.tar.gz
TST: Implemented an unused test for np.random.randint
In numpy/random/tests/test_random.py, a class called TestSingleEltArrayInput had a method called test_randint that was commented out, with the instructions to uncomment it once np.random.randint was able to broadcast arguments. Since np.random.randint has been able to broadcast arguments for a while now, I uncommented the test. The only modification I made to the code was fixing a small error, where the author incorrectly tried to call "assert_equal" as a method of the TestSingleEltArrayInput instead of a function that was imported from numpy.testing. I ran runtests.py, and the new test passed.
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/tests/test_random.py33
1 files changed, 16 insertions, 17 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index 773b63653..ecda5dab5 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -1712,23 +1712,22 @@ class TestSingleEltArrayInput:
out = func(self.argOne, argTwo[0])
assert_equal(out.shape, self.tgtShape)
-# TODO: Uncomment once randint can broadcast arguments
-# def test_randint(self):
-# itype = [bool, np.int8, np.uint8, np.int16, np.uint16,
-# np.int32, np.uint32, np.int64, np.uint64]
-# func = np.random.randint
-# high = np.array([1])
-# low = np.array([0])
-#
-# for dt in itype:
-# out = func(low, high, dtype=dt)
-# self.assert_equal(out.shape, self.tgtShape)
-#
-# out = func(low[0], high, dtype=dt)
-# self.assert_equal(out.shape, self.tgtShape)
-#
-# out = func(low, high[0], dtype=dt)
-# self.assert_equal(out.shape, self.tgtShape)
+ def test_randint(self):
+ itype = [bool, np.int8, np.uint8, np.int16, np.uint16,
+ np.int32, np.uint32, np.int64, np.uint64]
+ func = np.random.randint
+ high = np.array([1])
+ low = np.array([0])
+
+ for dt in itype:
+ out = func(low, high, dtype=dt)
+ assert_equal(out.shape, self.tgtShape)
+
+ out = func(low[0], high, dtype=dt)
+ assert_equal(out.shape, self.tgtShape)
+
+ out = func(low, high[0], dtype=dt)
+ assert_equal(out.shape, self.tgtShape)
def test_three_arg_funcs(self):
funcs = [np.random.noncentral_f, np.random.triangular,