diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-04-07 06:30:40 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-07 06:30:40 -0600 |
commit | 4f799383c22dea3ae4a7f5a7123df774281181b4 (patch) | |
tree | 67d7cf48d15cf3bd6a7d9548c5f6c7501e16f90f | |
parent | 914407d51b878bf7bf34dbd8dd72cc2dbc428673 (diff) | |
parent | d2cbd17cba2132e8c02da394be5c3f93ddd76919 (diff) | |
download | numpy-4f799383c22dea3ae4a7f5a7123df774281181b4.tar.gz |
Merge pull request #18731 from bashtage/fix-out-req
BUG: Check out requirements and raise when not satisfied
-rw-r--r-- | numpy/random/_common.pyx | 39 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 23 |
2 files changed, 53 insertions, 9 deletions
diff --git a/numpy/random/_common.pyx b/numpy/random/_common.pyx index 719647c3e..c397180fb 100644 --- a/numpy/random/_common.pyx +++ b/numpy/random/_common.pyx @@ -232,13 +232,34 @@ cdef validate_output_shape(iter_shape, np.ndarray output): ) -cdef check_output(object out, object dtype, object size): +cdef check_output(object out, object dtype, object size, bint require_c_array): + """ + Check user-supplied output array properties and shape + + Parameters + ---------- + out : {ndarray, None} + The array to check. If None, returns immediately. + dtype : dtype + The required dtype of out. + size : {None, int, tuple[int]} + The size passed. If out is an ndarray, verifies that the shape of out + matches size. + require_c_array : bool + Whether out must be a C-array. If False, out can be either C- or F- + ordered. If True, must be C-ordered. In either case, must be + contiguous, writable, aligned and in native byte-order. + """ if out is None: return cdef np.ndarray out_array = <np.ndarray>out - if not (np.PyArray_CHKFLAGS(out_array, np.NPY_CARRAY) or - np.PyArray_CHKFLAGS(out_array, np.NPY_FARRAY)): - raise ValueError('Supplied output array is not contiguous, writable or aligned.') + if not (np.PyArray_ISCARRAY(out_array) or + (np.PyArray_ISFARRAY(out_array) and not require_c_array)): + req = "C-" if require_c_array else "" + raise ValueError( + f'Supplied output array must be {req}contiguous, writable, ' + f'aligned, and in machine byte-order.' + ) if out_array.dtype != dtype: raise TypeError('Supplied output array has the wrong type. ' 'Expected {0}, got {1}'.format(np.dtype(dtype), out_array.dtype)) @@ -264,7 +285,7 @@ cdef object double_fill(void *func, bitgen_t *state, object size, object lock, o return out_val if out is not None: - check_output(out, np.float64, size) + check_output(out, np.float64, size, False) out_array = <np.ndarray>out else: out_array = <np.ndarray>np.empty(size, np.double) @@ -288,7 +309,7 @@ cdef object float_fill(void *func, bitgen_t *state, object size, object lock, ob return out_val if out is not None: - check_output(out, np.float32, size) + check_output(out, np.float32, size, False) out_array = <np.ndarray>out else: out_array = <np.ndarray>np.empty(size, np.float32) @@ -310,7 +331,7 @@ cdef object float_fill_from_double(void *func, bitgen_t *state, object size, obj return <float>random_func(state) if out is not None: - check_output(out, np.float32, size) + check_output(out, np.float32, size, False) out_array = <np.ndarray>out else: out_array = <np.ndarray>np.empty(size, np.float32) @@ -521,7 +542,7 @@ cdef object cont(void *func, void *state, object size, object lock, int narg, cdef np.ndarray a_arr, b_arr, c_arr cdef double _a = 0.0, _b = 0.0, _c = 0.0 cdef bint is_scalar = True - check_output(out, np.float64, size) + check_output(out, np.float64, size, narg > 0) if narg > 0: a_arr = <np.ndarray>np.PyArray_FROM_OTF(a, np.NPY_DOUBLE, np.NPY_ALIGNED) is_scalar = is_scalar and np.PyArray_NDIM(a_arr) == 0 @@ -971,7 +992,7 @@ cdef object cont_f(void *func, bitgen_t *state, object size, object lock, cdef float _a cdef bint is_scalar = True cdef int requirements = np.NPY_ALIGNED | np.NPY_FORCECAST - check_output(out, np.float32, size) + check_output(out, np.float32, size, True) a_arr = <np.ndarray>np.PyArray_FROMANY(a, np.NPY_FLOAT32, 0, 0, requirements) is_scalar = np.PyArray_NDIM(a_arr) == 0 diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 0108d84b3..4abcf6fe4 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -2581,3 +2581,26 @@ def test_single_arg_integer_exception(high, endpoint): gen.integers(-1, high, endpoint=endpoint) with pytest.raises(ValueError, match=msg): gen.integers([-1], high, endpoint=endpoint) + + +@pytest.mark.parametrize("dtype", ["f4", "f8"]) +def test_c_contig_req_out(dtype): + # GH 18704 + out = np.empty((2, 3), order="F", dtype=dtype) + shape = [1, 2, 3] + with pytest.raises(ValueError, match="Supplied output array"): + random.standard_gamma(shape, out=out, dtype=dtype) + with pytest.raises(ValueError, match="Supplied output array"): + random.standard_gamma(shape, out=out, size=out.shape, dtype=dtype) + + +@pytest.mark.parametrize("dtype", ["f4", "f8"]) +@pytest.mark.parametrize("order", ["F", "C"]) +@pytest.mark.parametrize("dist", [random.standard_normal, random.random]) +def test_contig_req_out(dist, order, dtype): + # GH 18704 + out = np.empty((2, 3), dtype=dtype, order=order) + variates = dist(out=out, dtype=dtype) + assert variates is out + variates = dist(out=out, dtype=dtype, size=out.shape) + assert variates is out |