summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@gmail.com>2021-01-23 14:16:17 +0100
committerRalf Gommers <ralf.gommers@gmail.com>2021-01-23 17:01:57 +0100
commit768acb7cf856bb49d4d183f8e9cbd456ecc32475 (patch)
tree0585eab99f08fc5ffd00f1606890a630bec7ac9f
parent164959a9049c05901dd4a2cc66c07a4df37527c1 (diff)
downloadnumpy-768acb7cf856bb49d4d183f8e9cbd456ecc32475.tar.gz
BUG: shuffling empty array with axis=1 was broken
This would trigger: ``` NotImplementedError: Axis argument is only supported on ndarray objects ``` because empty arrays and array scalars would take the "untyped" path. The bug exists only for Generator, not in RandomState (it doesn't have axis keyword for `shuffle`), but update both because it keeps implementations in sync and the change results in more understandable code also for RandomState.
-rw-r--r--numpy/random/_generator.pyx15
-rw-r--r--numpy/random/mtrand.pyx11
-rw-r--r--numpy/random/tests/test_generator_mt19937.py15
-rw-r--r--numpy/random/tests/test_randomstate.py6
4 files changed, 39 insertions, 8 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index 5a7d4a21a..ba1713dfa 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -4415,7 +4415,15 @@ cdef class Generator:
with self.lock, nogil:
_shuffle_raw_wrap(&self._bitgen, n, 1, itemsize, stride,
x_ptr, buf_ptr)
- elif isinstance(x, np.ndarray) and x.ndim and x.size:
+ elif isinstance(x, np.ndarray):
+ if axis >= x.ndim:
+ raise np.AxisError(f"Cannot shuffle along axis {axis} for "
+ f"array of dimension {x.ndim}")
+
+ if x.size == 0:
+ # shuffling is a no-op
+ return
+
x = np.swapaxes(x, 0, axis)
buf = np.empty_like(x[0, ...])
with self.lock:
@@ -4429,12 +4437,13 @@ cdef class Generator:
x[i] = buf
else:
# Untyped path.
- if not isinstance(x, (np.ndarray, MutableSequence)):
+ if not isinstance(x, MutableSequence):
# See gh-18206. We may decide to deprecate here in the future.
warnings.warn(
"`x` isn't a recognized object; `shuffle` is not guaranteed "
"to behave correctly. E.g., non-numpy array/tensor objects "
- "with view semantics may contain duplicates after shuffling."
+ "with view semantics may contain duplicates after shuffling.",
+ UserWarning, stacklevel=2
)
if axis != 0:
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 9ee3d9ff3..814630c03 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -4457,7 +4457,11 @@ cdef class RandomState:
self._shuffle_raw(n, sizeof(np.npy_intp), stride, x_ptr, buf_ptr)
else:
self._shuffle_raw(n, itemsize, stride, x_ptr, buf_ptr)
- elif isinstance(x, np.ndarray) and x.ndim and x.size:
+ elif isinstance(x, np.ndarray):
+ if x.size == 0:
+ # shuffling is a no-op
+ return
+
buf = np.empty_like(x[0, ...])
with self.lock:
for i in reversed(range(1, n)):
@@ -4469,12 +4473,13 @@ cdef class RandomState:
x[i] = buf
else:
# Untyped path.
- if not isinstance(x, (np.ndarray, MutableSequence)):
+ if not isinstance(x, MutableSequence):
# See gh-18206. We may decide to deprecate here in the future.
warnings.warn(
"`x` isn't a recognized object; `shuffle` is not guaranteed "
"to behave correctly. E.g., non-numpy array/tensor objects "
- "with view semantics may contain duplicates after shuffling."
+ "with view semantics may contain duplicates after shuffling.",
+ UserWarning, stacklevel=2
)
with self.lock:
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index c4fb5883c..47c81584c 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -960,6 +960,14 @@ class TestRandomDist:
random.shuffle(actual, axis=-1)
assert_array_equal(actual, desired)
+ def test_shuffle_custom_axis_empty(self):
+ random = Generator(MT19937(self.seed))
+ desired = np.array([]).reshape((0, 6))
+ for axis in (0, 1):
+ actual = np.array([]).reshape((0, 6))
+ random.shuffle(actual, axis=axis)
+ assert_array_equal(actual, desired)
+
def test_shuffle_axis_nonsquare(self):
y1 = np.arange(20).reshape(2, 10)
y2 = y1.copy()
@@ -993,6 +1001,11 @@ class TestRandomDist:
arr = [[1, 2, 3], [4, 5, 6]]
assert_raises(NotImplementedError, random.shuffle, arr, 1)
+ arr = np.array(3)
+ assert_raises(TypeError, random.shuffle, arr)
+ arr = np.ones((3, 2))
+ assert_raises(np.AxisError, random.shuffle, arr, 2)
+
def test_permutation(self):
random = Generator(MT19937(self.seed))
alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
@@ -1004,7 +1017,7 @@ class TestRandomDist:
arr_2d = np.atleast_2d([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]).T
actual = random.permutation(arr_2d)
assert_array_equal(actual, np.atleast_2d(desired).T)
-
+
bad_x_str = "abcd"
assert_raises(np.AxisError, random.permutation, bad_x_str)
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index b70a04347..7f5f08050 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -642,7 +642,7 @@ class TestRandomDist:
a = np.array([42, 1, 2])
p = [None, None, None]
assert_raises(ValueError, random.choice, a, p=p)
-
+
def test_choice_p_non_contiguous(self):
p = np.ones(10) / 5
p[1::2] = 3.0
@@ -699,6 +699,10 @@ class TestRandomDist:
assert_equal(
sorted(b.data[~b.mask]), sorted(b_orig.data[~b_orig.mask]))
+ def test_shuffle_invalid_objects(self):
+ x = np.array(3)
+ assert_raises(TypeError, random.shuffle, x)
+
def test_permutation(self):
random.seed(self.seed)
alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]