diff options
author | Ralf Gommers <ralf.gommers@gmail.com> | 2021-01-23 14:16:17 +0100 |
---|---|---|
committer | Ralf Gommers <ralf.gommers@gmail.com> | 2021-01-23 17:01:57 +0100 |
commit | 768acb7cf856bb49d4d183f8e9cbd456ecc32475 (patch) | |
tree | 0585eab99f08fc5ffd00f1606890a630bec7ac9f | |
parent | 164959a9049c05901dd4a2cc66c07a4df37527c1 (diff) | |
download | numpy-768acb7cf856bb49d4d183f8e9cbd456ecc32475.tar.gz |
BUG: shuffling empty array with axis=1 was broken
This would trigger:
```
NotImplementedError: Axis argument is only supported on ndarray objects
```
because empty arrays and array scalars would take the "untyped" path.
The bug exists only for Generator, not in RandomState (it doesn't have
axis keyword for `shuffle`), but update both because it keeps
implementations in sync and the change results in more understandable
code also for RandomState.
-rw-r--r-- | numpy/random/_generator.pyx | 15 | ||||
-rw-r--r-- | numpy/random/mtrand.pyx | 11 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 15 | ||||
-rw-r--r-- | numpy/random/tests/test_randomstate.py | 6 |
4 files changed, 39 insertions, 8 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index 5a7d4a21a..ba1713dfa 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -4415,7 +4415,15 @@ cdef class Generator: with self.lock, nogil: _shuffle_raw_wrap(&self._bitgen, n, 1, itemsize, stride, x_ptr, buf_ptr) - elif isinstance(x, np.ndarray) and x.ndim and x.size: + elif isinstance(x, np.ndarray): + if axis >= x.ndim: + raise np.AxisError(f"Cannot shuffle along axis {axis} for " + f"array of dimension {x.ndim}") + + if x.size == 0: + # shuffling is a no-op + return + x = np.swapaxes(x, 0, axis) buf = np.empty_like(x[0, ...]) with self.lock: @@ -4429,12 +4437,13 @@ cdef class Generator: x[i] = buf else: # Untyped path. - if not isinstance(x, (np.ndarray, MutableSequence)): + if not isinstance(x, MutableSequence): # See gh-18206. We may decide to deprecate here in the future. warnings.warn( "`x` isn't a recognized object; `shuffle` is not guaranteed " "to behave correctly. E.g., non-numpy array/tensor objects " - "with view semantics may contain duplicates after shuffling." + "with view semantics may contain duplicates after shuffling.", + UserWarning, stacklevel=2 ) if axis != 0: diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index 9ee3d9ff3..814630c03 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -4457,7 +4457,11 @@ cdef class RandomState: self._shuffle_raw(n, sizeof(np.npy_intp), stride, x_ptr, buf_ptr) else: self._shuffle_raw(n, itemsize, stride, x_ptr, buf_ptr) - elif isinstance(x, np.ndarray) and x.ndim and x.size: + elif isinstance(x, np.ndarray): + if x.size == 0: + # shuffling is a no-op + return + buf = np.empty_like(x[0, ...]) with self.lock: for i in reversed(range(1, n)): @@ -4469,12 +4473,13 @@ cdef class RandomState: x[i] = buf else: # Untyped path. - if not isinstance(x, (np.ndarray, MutableSequence)): + if not isinstance(x, MutableSequence): # See gh-18206. We may decide to deprecate here in the future. warnings.warn( "`x` isn't a recognized object; `shuffle` is not guaranteed " "to behave correctly. E.g., non-numpy array/tensor objects " - "with view semantics may contain duplicates after shuffling." + "with view semantics may contain duplicates after shuffling.", + UserWarning, stacklevel=2 ) with self.lock: diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index c4fb5883c..47c81584c 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -960,6 +960,14 @@ class TestRandomDist: random.shuffle(actual, axis=-1) assert_array_equal(actual, desired) + def test_shuffle_custom_axis_empty(self): + random = Generator(MT19937(self.seed)) + desired = np.array([]).reshape((0, 6)) + for axis in (0, 1): + actual = np.array([]).reshape((0, 6)) + random.shuffle(actual, axis=axis) + assert_array_equal(actual, desired) + def test_shuffle_axis_nonsquare(self): y1 = np.arange(20).reshape(2, 10) y2 = y1.copy() @@ -993,6 +1001,11 @@ class TestRandomDist: arr = [[1, 2, 3], [4, 5, 6]] assert_raises(NotImplementedError, random.shuffle, arr, 1) + arr = np.array(3) + assert_raises(TypeError, random.shuffle, arr) + arr = np.ones((3, 2)) + assert_raises(np.AxisError, random.shuffle, arr, 2) + def test_permutation(self): random = Generator(MT19937(self.seed)) alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] @@ -1004,7 +1017,7 @@ class TestRandomDist: arr_2d = np.atleast_2d([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]).T actual = random.permutation(arr_2d) assert_array_equal(actual, np.atleast_2d(desired).T) - + bad_x_str = "abcd" assert_raises(np.AxisError, random.permutation, bad_x_str) diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py index b70a04347..7f5f08050 100644 --- a/numpy/random/tests/test_randomstate.py +++ b/numpy/random/tests/test_randomstate.py @@ -642,7 +642,7 @@ class TestRandomDist: a = np.array([42, 1, 2]) p = [None, None, None] assert_raises(ValueError, random.choice, a, p=p) - + def test_choice_p_non_contiguous(self): p = np.ones(10) / 5 p[1::2] = 3.0 @@ -699,6 +699,10 @@ class TestRandomDist: assert_equal( sorted(b.data[~b.mask]), sorted(b_orig.data[~b_orig.mask])) + def test_shuffle_invalid_objects(self): + x = np.array(3) + assert_raises(TypeError, random.shuffle, x) + def test_permutation(self): random.seed(self.seed) alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] |