summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Sheppard <kevin.sheppard@gmail.com>2021-02-26 23:06:56 +0000
committerKevin Sheppard <kevin.sheppard@gmail.com>2021-02-26 23:44:50 +0000
commitb1015adfdbce55b7ee9211baca2f51284d67694a (patch)
tree7b9a745abd5912105c95b9dcb7bb099c5a4951eb
parente900be2bf367d2a90af922e56c27aadcd0581bdf (diff)
downloadnumpy-b1015adfdbce55b7ee9211baca2f51284d67694a.tar.gz
Port error to RandomState
-rw-r--r--numpy/random/mtrand.pyx15
-rw-r--r--numpy/random/tests/test_generator_mt19937.py2
-rw-r--r--numpy/random/tests/test_randomstate.py8
3 files changed, 23 insertions, 2 deletions
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 6f44e271f..4e12f8e59 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -4232,7 +4232,20 @@ cdef class RandomState:
pix = <double*>np.PyArray_DATA(parr)
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
- raise ValueError("sum(pvals[:-1]) > 1.0")
+ # When floating, but not float dtype, and close, improve the error
+ # 1.0001 works for float16 and float32
+ if (isinstance(pvals, np.ndarray)
+ and np.issubdtype(pvals.dtype, np.floating)
+ and pvals.dtype != float
+ and pvals.sum() < 1.0001):
+ msg = ("sum(pvals[:-1].astype(np.float64)) > 1.0. The pvals "
+ "array is cast to 64-bit floating point prior to "
+ "checking the sum. Precision changes when casting may "
+ "cause problems even if the sum of the original pvals "
+ "is valid.")
+ else:
+ msg = "sum(pvals[:-1]) > 1.0"
+ raise ValueError(msg)
if size is None:
shape = (d,)
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 9de044774..446b350dd 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -147,7 +147,7 @@ class TestMultinomial:
1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32)
pvals = x / x.sum()
random = Generator(MT19937(1432985819))
- match = r"[\w\s]*pvals are cast to 64-bit floating"
+ match = r"[\w\s]*pvals array is cast to 64-bit floating"
with pytest.raises(ValueError, match=match):
random.multinomial(1, pvals)
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index b16275b70..861813a95 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -167,6 +167,14 @@ class TestMultinomial:
contig = random.multinomial(100, pvals=np.ascontiguousarray(pvals))
assert_array_equal(non_contig, contig)
+ def test_multinomial_pvals_float32(self):
+ x = np.array([9.9e-01, 9.9e-01, 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09,
+ 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32)
+ pvals = x / x.sum()
+ match = r"[\w\s]*pvals array is cast to 64-bit floating"
+ with pytest.raises(ValueError, match=match):
+ random.multinomial(1, pvals)
+
class TestSetState:
def setup(self):