diff options
| author | Charles Harris <charlesr.harris@gmail.com> | 2020-06-07 19:26:31 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-06-07 19:26:31 -0600 |
| commit | 9e2d66a110090ee7416f4686593fdf9142ce6cbc (patch) | |
| tree | 303738c685de3607895b890d6cee1683d28a5497 | |
| parent | a60d5e8d7652641b861587f1fc8efa3cf2081bb5 (diff) | |
| parent | 776189a168f12c8da385a48ab1c381133908fa73 (diff) | |
| download | numpy-9e2d66a110090ee7416f4686593fdf9142ce6cbc.tar.gz | |
Merge pull request #16503 from bashtage/bug-broadcast-size
BUG:random: Error when ``size`` is smaller than broadcast input shapes.
| -rw-r--r-- | numpy/random/_common.pyx | 29 | ||||
| -rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 41 |
2 files changed, 70 insertions, 0 deletions
diff --git a/numpy/random/_common.pyx b/numpy/random/_common.pyx index ef1afac7c..fd5f8addc 100644 --- a/numpy/random/_common.pyx +++ b/numpy/random/_common.pyx @@ -218,6 +218,19 @@ cdef np.ndarray int_to_array(object value, object name, object bits, object uint return out +cdef validate_output_shape(iter_shape, np.ndarray output): + cdef np.npy_intp *shape, ndim, i + cdef bint error + dims = np.PyArray_DIMS(output) + ndim = np.PyArray_NDIM(output) + output_shape = tuple((dims[i] for i in range(ndim))) + if iter_shape != output_shape: + raise ValueError( + f"Output size {output_shape} is not compatible with broadcast " + f"dimensions of inputs {iter_shape}." + ) + + cdef check_output(object out, object dtype, object size): if out is None: return @@ -404,6 +417,7 @@ cdef object cont_broadcast_1(void *func, void *state, object size, object lock, randoms_data = <double *>np.PyArray_DATA(randoms) n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew2(randoms, a_arr) + validate_output_shape(it.shape, randoms) with lock, nogil: for i in range(n): @@ -441,6 +455,8 @@ cdef object cont_broadcast_2(void *func, void *state, object size, object lock, n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew3(randoms, a_arr, b_arr) + validate_output_shape(it.shape, randoms) + with lock, nogil: for i in range(n): a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0] @@ -482,6 +498,8 @@ cdef object cont_broadcast_3(void *func, void *state, object size, object lock, n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew4(randoms, a_arr, b_arr, c_arr) + validate_output_shape(it.shape, randoms) + with lock, nogil: for i in range(n): a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0] @@ -611,6 +629,8 @@ cdef object discrete_broadcast_d(void *func, void *state, object size, object lo n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew2(randoms, a_arr) + validate_output_shape(it.shape, randoms) + with lock, nogil: for i in range(n): a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0] @@ -645,6 +665,8 @@ cdef object discrete_broadcast_dd(void *func, void *state, object size, object l n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew3(randoms, a_arr, b_arr) + validate_output_shape(it.shape, randoms) + with lock, nogil: for i in range(n): a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0] @@ -680,6 +702,8 @@ cdef object discrete_broadcast_di(void *func, void *state, object size, object l n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew3(randoms, a_arr, b_arr) + validate_output_shape(it.shape, randoms) + with lock, nogil: for i in range(n): a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0] @@ -719,6 +743,8 @@ cdef object discrete_broadcast_iii(void *func, void *state, object size, object n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew4(randoms, a_arr, b_arr, c_arr) + validate_output_shape(it.shape, randoms) + with lock, nogil: for i in range(n): a_val = (<int64_t*>np.PyArray_MultiIter_DATA(it, 1))[0] @@ -750,6 +776,8 @@ cdef object discrete_broadcast_i(void *func, void *state, object size, object lo n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew2(randoms, a_arr) + validate_output_shape(it.shape, randoms) + with lock, nogil: for i in range(n): a_val = (<int64_t*>np.PyArray_MultiIter_DATA(it, 1))[0] @@ -923,6 +951,7 @@ cdef object cont_broadcast_1_f(void *func, bitgen_t *state, object size, object randoms_data = <float *>np.PyArray_DATA(randoms) n = np.PyArray_SIZE(randoms) it = np.PyArray_MultiIterNew2(randoms, a_arr) + validate_output_shape(it.shape, randoms) with lock, nogil: for i in range(n): diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index f72b748ba..332b63198 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -2397,3 +2397,44 @@ def test_jumped(config): md5 = hashlib.md5(key) assert jumped.state["state"]["pos"] == config["jumped"]["pos"] assert md5.hexdigest() == config["jumped"]["key_md5"] + + +def test_broadcast_size_error(): + mu = np.ones(3) + sigma = np.ones((4, 3)) + size = (10, 4, 2) + assert random.normal(mu, sigma, size=(5, 4, 3)).shape == (5, 4, 3) + with pytest.raises(ValueError): + random.normal(mu, sigma, size=size) + with pytest.raises(ValueError): + random.normal(mu, sigma, size=(1, 3)) + with pytest.raises(ValueError): + random.normal(mu, sigma, size=(4, 1, 1)) + # 1 arg + shape = np.ones((4, 3)) + with pytest.raises(ValueError): + random.standard_gamma(shape, size=size) + with pytest.raises(ValueError): + random.standard_gamma(shape, size=(3,)) + with pytest.raises(ValueError): + random.standard_gamma(shape, size=3) + # Check out + out = np.empty(size) + with pytest.raises(ValueError): + random.standard_gamma(shape, out=out) + + # 3 arg + a = random.chisquare(5, size=3) + b = random.chisquare(5, size=(4, 3)) + c = random.chisquare(5, size=(5, 4, 3)) + assert random.noncentral_f(a, b, c).shape == (5, 4, 3) + with pytest.raises(ValueError, match=r"Output size \(6, 5, 1, 1\) is"): + random.noncentral_f(a, b, c, size=(6, 5, 1, 1)) + + +def test_broadcast_size_scalar(): + mu = np.ones(3) + sigma = np.ones(3) + random.normal(mu, sigma, size=3) + with pytest.raises(ValueError): + random.normal(mu, sigma, size=2) |
