diff options
| author | Matteo Raso <mraso@uoguelph.ca> | 2022-08-25 23:09:46 -0400 |
|---|---|---|
| committer | Matteo Raso <mraso@uoguelph.ca> | 2022-08-25 23:09:46 -0400 |
| commit | 1647f467a7fcb000ffeff0f2fd4b30f23a32d882 (patch) | |
| tree | 2795bba990ded4affde724d72ffbc42c033e1e48 /numpy/random | |
| parent | 50a74fb65fc752e77a2f9e9e2b7227629c2ba953 (diff) | |
| download | numpy-1647f467a7fcb000ffeff0f2fd4b30f23a32d882.tar.gz | |
TST: Implemented an unused test for np.random.randint
In numpy/random/tests/test_random.py, a class called TestSingleEltArrayInput had a method called test_randint that was commented out, with the instructions to uncomment it once np.random.randint was able to broadcast arguments. Since np.random.randint has been able to broadcast arguments for a while now, I uncommented the test. The only modification I made to the code was fixing a small error, where the author incorrectly tried to call "assert_equal" as a method of the TestSingleEltArrayInput instead of a function that was imported from numpy.testing. I ran runtests.py, and the new test passed.
Diffstat (limited to 'numpy/random')
| -rw-r--r-- | numpy/random/tests/test_random.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index 773b63653..ecda5dab5 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -1712,23 +1712,22 @@ class TestSingleEltArrayInput: out = func(self.argOne, argTwo[0]) assert_equal(out.shape, self.tgtShape) -# TODO: Uncomment once randint can broadcast arguments -# def test_randint(self): -# itype = [bool, np.int8, np.uint8, np.int16, np.uint16, -# np.int32, np.uint32, np.int64, np.uint64] -# func = np.random.randint -# high = np.array([1]) -# low = np.array([0]) -# -# for dt in itype: -# out = func(low, high, dtype=dt) -# self.assert_equal(out.shape, self.tgtShape) -# -# out = func(low[0], high, dtype=dt) -# self.assert_equal(out.shape, self.tgtShape) -# -# out = func(low, high[0], dtype=dt) -# self.assert_equal(out.shape, self.tgtShape) + def test_randint(self): + itype = [bool, np.int8, np.uint8, np.int16, np.uint16, + np.int32, np.uint32, np.int64, np.uint64] + func = np.random.randint + high = np.array([1]) + low = np.array([0]) + + for dt in itype: + out = func(low, high, dtype=dt) + assert_equal(out.shape, self.tgtShape) + + out = func(low[0], high, dtype=dt) + assert_equal(out.shape, self.tgtShape) + + out = func(low, high[0], dtype=dt) + assert_equal(out.shape, self.tgtShape) def test_three_arg_funcs(self): funcs = [np.random.noncentral_f, np.random.triangular, |
