summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2020-06-07 19:26:31 -0600
committerGitHub <noreply@github.com>2020-06-07 19:26:31 -0600
commit9e2d66a110090ee7416f4686593fdf9142ce6cbc (patch)
tree303738c685de3607895b890d6cee1683d28a5497
parenta60d5e8d7652641b861587f1fc8efa3cf2081bb5 (diff)
parent776189a168f12c8da385a48ab1c381133908fa73 (diff)
downloadnumpy-9e2d66a110090ee7416f4686593fdf9142ce6cbc.tar.gz
Merge pull request #16503 from bashtage/bug-broadcast-size
BUG:random: Error when ``size`` is smaller than broadcast input shapes.
-rw-r--r--numpy/random/_common.pyx29
-rw-r--r--numpy/random/tests/test_generator_mt19937.py41
2 files changed, 70 insertions, 0 deletions
diff --git a/numpy/random/_common.pyx b/numpy/random/_common.pyx
index ef1afac7c..fd5f8addc 100644
--- a/numpy/random/_common.pyx
+++ b/numpy/random/_common.pyx
@@ -218,6 +218,19 @@ cdef np.ndarray int_to_array(object value, object name, object bits, object uint
return out
+cdef validate_output_shape(iter_shape, np.ndarray output):
+ cdef np.npy_intp *shape, ndim, i
+ cdef bint error
+ dims = np.PyArray_DIMS(output)
+ ndim = np.PyArray_NDIM(output)
+ output_shape = tuple((dims[i] for i in range(ndim)))
+ if iter_shape != output_shape:
+ raise ValueError(
+ f"Output size {output_shape} is not compatible with broadcast "
+ f"dimensions of inputs {iter_shape}."
+ )
+
+
cdef check_output(object out, object dtype, object size):
if out is None:
return
@@ -404,6 +417,7 @@ cdef object cont_broadcast_1(void *func, void *state, object size, object lock,
randoms_data = <double *>np.PyArray_DATA(randoms)
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew2(randoms, a_arr)
+ validate_output_shape(it.shape, randoms)
with lock, nogil:
for i in range(n):
@@ -441,6 +455,8 @@ cdef object cont_broadcast_2(void *func, void *state, object size, object lock,
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew3(randoms, a_arr, b_arr)
+ validate_output_shape(it.shape, randoms)
+
with lock, nogil:
for i in range(n):
a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
@@ -482,6 +498,8 @@ cdef object cont_broadcast_3(void *func, void *state, object size, object lock,
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew4(randoms, a_arr, b_arr, c_arr)
+ validate_output_shape(it.shape, randoms)
+
with lock, nogil:
for i in range(n):
a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
@@ -611,6 +629,8 @@ cdef object discrete_broadcast_d(void *func, void *state, object size, object lo
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew2(randoms, a_arr)
+ validate_output_shape(it.shape, randoms)
+
with lock, nogil:
for i in range(n):
a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
@@ -645,6 +665,8 @@ cdef object discrete_broadcast_dd(void *func, void *state, object size, object l
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew3(randoms, a_arr, b_arr)
+ validate_output_shape(it.shape, randoms)
+
with lock, nogil:
for i in range(n):
a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
@@ -680,6 +702,8 @@ cdef object discrete_broadcast_di(void *func, void *state, object size, object l
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew3(randoms, a_arr, b_arr)
+ validate_output_shape(it.shape, randoms)
+
with lock, nogil:
for i in range(n):
a_val = (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
@@ -719,6 +743,8 @@ cdef object discrete_broadcast_iii(void *func, void *state, object size, object
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew4(randoms, a_arr, b_arr, c_arr)
+ validate_output_shape(it.shape, randoms)
+
with lock, nogil:
for i in range(n):
a_val = (<int64_t*>np.PyArray_MultiIter_DATA(it, 1))[0]
@@ -750,6 +776,8 @@ cdef object discrete_broadcast_i(void *func, void *state, object size, object lo
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew2(randoms, a_arr)
+ validate_output_shape(it.shape, randoms)
+
with lock, nogil:
for i in range(n):
a_val = (<int64_t*>np.PyArray_MultiIter_DATA(it, 1))[0]
@@ -923,6 +951,7 @@ cdef object cont_broadcast_1_f(void *func, bitgen_t *state, object size, object
randoms_data = <float *>np.PyArray_DATA(randoms)
n = np.PyArray_SIZE(randoms)
it = np.PyArray_MultiIterNew2(randoms, a_arr)
+ validate_output_shape(it.shape, randoms)
with lock, nogil:
for i in range(n):
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index f72b748ba..332b63198 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -2397,3 +2397,44 @@ def test_jumped(config):
md5 = hashlib.md5(key)
assert jumped.state["state"]["pos"] == config["jumped"]["pos"]
assert md5.hexdigest() == config["jumped"]["key_md5"]
+
+
+def test_broadcast_size_error():
+ mu = np.ones(3)
+ sigma = np.ones((4, 3))
+ size = (10, 4, 2)
+ assert random.normal(mu, sigma, size=(5, 4, 3)).shape == (5, 4, 3)
+ with pytest.raises(ValueError):
+ random.normal(mu, sigma, size=size)
+ with pytest.raises(ValueError):
+ random.normal(mu, sigma, size=(1, 3))
+ with pytest.raises(ValueError):
+ random.normal(mu, sigma, size=(4, 1, 1))
+ # 1 arg
+ shape = np.ones((4, 3))
+ with pytest.raises(ValueError):
+ random.standard_gamma(shape, size=size)
+ with pytest.raises(ValueError):
+ random.standard_gamma(shape, size=(3,))
+ with pytest.raises(ValueError):
+ random.standard_gamma(shape, size=3)
+ # Check out
+ out = np.empty(size)
+ with pytest.raises(ValueError):
+ random.standard_gamma(shape, out=out)
+
+ # 3 arg
+ a = random.chisquare(5, size=3)
+ b = random.chisquare(5, size=(4, 3))
+ c = random.chisquare(5, size=(5, 4, 3))
+ assert random.noncentral_f(a, b, c).shape == (5, 4, 3)
+ with pytest.raises(ValueError, match=r"Output size \(6, 5, 1, 1\) is"):
+ random.noncentral_f(a, b, c, size=(6, 5, 1, 1))
+
+
+def test_broadcast_size_scalar():
+ mu = np.ones(3)
+ sigma = np.ones(3)
+ random.normal(mu, sigma, size=3)
+ with pytest.raises(ValueError):
+ random.normal(mu, sigma, size=2)