summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-04-07 06:30:40 -0600
committerGitHub <noreply@github.com>2021-04-07 06:30:40 -0600
commit4f799383c22dea3ae4a7f5a7123df774281181b4 (patch)
tree67d7cf48d15cf3bd6a7d9548c5f6c7501e16f90f
parent914407d51b878bf7bf34dbd8dd72cc2dbc428673 (diff)
parentd2cbd17cba2132e8c02da394be5c3f93ddd76919 (diff)
downloadnumpy-4f799383c22dea3ae4a7f5a7123df774281181b4.tar.gz
Merge pull request #18731 from bashtage/fix-out-req
BUG: Check out requirements and raise when not satisfied
-rw-r--r--numpy/random/_common.pyx39
-rw-r--r--numpy/random/tests/test_generator_mt19937.py23
2 files changed, 53 insertions, 9 deletions
diff --git a/numpy/random/_common.pyx b/numpy/random/_common.pyx
index 719647c3e..c397180fb 100644
--- a/numpy/random/_common.pyx
+++ b/numpy/random/_common.pyx
@@ -232,13 +232,34 @@ cdef validate_output_shape(iter_shape, np.ndarray output):
)
-cdef check_output(object out, object dtype, object size):
+cdef check_output(object out, object dtype, object size, bint require_c_array):
+ """
+ Check user-supplied output array properties and shape
+
+ Parameters
+ ----------
+ out : {ndarray, None}
+ The array to check. If None, returns immediately.
+ dtype : dtype
+ The required dtype of out.
+ size : {None, int, tuple[int]}
+ The size passed. If out is an ndarray, verifies that the shape of out
+ matches size.
+ require_c_array : bool
+ Whether out must be a C-array. If False, out can be either C- or F-
+ ordered. If True, must be C-ordered. In either case, must be
+ contiguous, writable, aligned and in native byte-order.
+ """
if out is None:
return
cdef np.ndarray out_array = <np.ndarray>out
- if not (np.PyArray_CHKFLAGS(out_array, np.NPY_CARRAY) or
- np.PyArray_CHKFLAGS(out_array, np.NPY_FARRAY)):
- raise ValueError('Supplied output array is not contiguous, writable or aligned.')
+ if not (np.PyArray_ISCARRAY(out_array) or
+ (np.PyArray_ISFARRAY(out_array) and not require_c_array)):
+ req = "C-" if require_c_array else ""
+ raise ValueError(
+ f'Supplied output array must be {req}contiguous, writable, '
+ f'aligned, and in machine byte-order.'
+ )
if out_array.dtype != dtype:
raise TypeError('Supplied output array has the wrong type. '
'Expected {0}, got {1}'.format(np.dtype(dtype), out_array.dtype))
@@ -264,7 +285,7 @@ cdef object double_fill(void *func, bitgen_t *state, object size, object lock, o
return out_val
if out is not None:
- check_output(out, np.float64, size)
+ check_output(out, np.float64, size, False)
out_array = <np.ndarray>out
else:
out_array = <np.ndarray>np.empty(size, np.double)
@@ -288,7 +309,7 @@ cdef object float_fill(void *func, bitgen_t *state, object size, object lock, ob
return out_val
if out is not None:
- check_output(out, np.float32, size)
+ check_output(out, np.float32, size, False)
out_array = <np.ndarray>out
else:
out_array = <np.ndarray>np.empty(size, np.float32)
@@ -310,7 +331,7 @@ cdef object float_fill_from_double(void *func, bitgen_t *state, object size, obj
return <float>random_func(state)
if out is not None:
- check_output(out, np.float32, size)
+ check_output(out, np.float32, size, False)
out_array = <np.ndarray>out
else:
out_array = <np.ndarray>np.empty(size, np.float32)
@@ -521,7 +542,7 @@ cdef object cont(void *func, void *state, object size, object lock, int narg,
cdef np.ndarray a_arr, b_arr, c_arr
cdef double _a = 0.0, _b = 0.0, _c = 0.0
cdef bint is_scalar = True
- check_output(out, np.float64, size)
+ check_output(out, np.float64, size, narg > 0)
if narg > 0:
a_arr = <np.ndarray>np.PyArray_FROM_OTF(a, np.NPY_DOUBLE, np.NPY_ALIGNED)
is_scalar = is_scalar and np.PyArray_NDIM(a_arr) == 0
@@ -971,7 +992,7 @@ cdef object cont_f(void *func, bitgen_t *state, object size, object lock,
cdef float _a
cdef bint is_scalar = True
cdef int requirements = np.NPY_ALIGNED | np.NPY_FORCECAST
- check_output(out, np.float32, size)
+ check_output(out, np.float32, size, True)
a_arr = <np.ndarray>np.PyArray_FROMANY(a, np.NPY_FLOAT32, 0, 0, requirements)
is_scalar = np.PyArray_NDIM(a_arr) == 0
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 0108d84b3..4abcf6fe4 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -2581,3 +2581,26 @@ def test_single_arg_integer_exception(high, endpoint):
gen.integers(-1, high, endpoint=endpoint)
with pytest.raises(ValueError, match=msg):
gen.integers([-1], high, endpoint=endpoint)
+
+
+@pytest.mark.parametrize("dtype", ["f4", "f8"])
+def test_c_contig_req_out(dtype):
+ # GH 18704
+ out = np.empty((2, 3), order="F", dtype=dtype)
+ shape = [1, 2, 3]
+ with pytest.raises(ValueError, match="Supplied output array"):
+ random.standard_gamma(shape, out=out, dtype=dtype)
+ with pytest.raises(ValueError, match="Supplied output array"):
+ random.standard_gamma(shape, out=out, size=out.shape, dtype=dtype)
+
+
+@pytest.mark.parametrize("dtype", ["f4", "f8"])
+@pytest.mark.parametrize("order", ["F", "C"])
+@pytest.mark.parametrize("dist", [random.standard_normal, random.random])
+def test_contig_req_out(dist, order, dtype):
+ # GH 18704
+ out = np.empty((2, 3), dtype=dtype, order=order)
+ variates = dist(out=out, dtype=dtype)
+ assert variates is out
+ variates = dist(out=out, dtype=dtype, size=out.shape)
+ assert variates is out