summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2014-07-04 18:54:17 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-08-04 23:31:21 +0200
commit763aeeaffcd1572d4da3d7f98f7627d0e271df26 (patch)
treea6251fa1095ff6f1923c8c26e11471f5cf4aea26
parente715bce009c05bbeb5819f2b1d0468c6b776e3e3 (diff)
downloadnumpy-763aeeaffcd1572d4da3d7f98f7627d0e271df26.tar.gz
BUG: wrong selection for orders falling into equal ranges
when orders are selected where the kth element falls into an equal range the the last stored pivot was not the kth element, this leads to losing the ordering of smaller orders as following selection steps can start at index 0 again instead of the at the offset of the last selection. Closes gh-4836
-rw-r--r--numpy/core/src/npysort/selection.c.src10
-rw-r--r--numpy/core/tests/test_multiarray.py18
2 files changed, 25 insertions, 3 deletions
diff --git a/numpy/core/src/npysort/selection.c.src b/numpy/core/src/npysort/selection.c.src
index b11753367..bd0d97153 100644
--- a/numpy/core/src/npysort/selection.c.src
+++ b/numpy/core/src/npysort/selection.c.src
@@ -379,7 +379,10 @@ int
/* move pivot into position */
SWAP(SORTEE(low), SORTEE(hh));
- store_pivot(hh, kth, pivots, npiv);
+ /* kth pivot stored later */
+ if (hh != kth) {
+ store_pivot(hh, kth, pivots, npiv);
+ }
if (hh >= kth)
high = hh - 1;
@@ -389,10 +392,11 @@ int
/* two elements */
if (high == low + 1) {
- if (@TYPE@_LT(v[IDX(high)], v[IDX(low)]))
+ if (@TYPE@_LT(v[IDX(high)], v[IDX(low)])) {
SWAP(SORTEE(high), SORTEE(low))
- store_pivot(low, kth, pivots, npiv);
+ }
}
+ store_pivot(kth, kth, pivots, npiv);
return 0;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index e1c1e26d9..fe210177a 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1143,6 +1143,12 @@ class TestMethods(TestCase):
d[i:].partition(0, kind=k)
assert_array_equal(d, tgt)
+ d = np.array([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 9])
+ kth = [0, 3, 19, 20]
+ assert_equal(np.partition(d, kth, kind=k)[kth], (0, 3, 7, 7))
+ assert_equal(d[np.argpartition(d, kth, kind=k)][kth], (0, 3, 7, 7))
+
d = np.array([2, 1])
d.partition(0, kind=k)
assert_raises(ValueError, d.partition, 2)
@@ -1332,6 +1338,18 @@ class TestMethods(TestCase):
assert_equal(np.partition(d, k)[k], tgt[k])
assert_equal(d[np.argpartition(d, k)][k], tgt[k])
+ def test_partition_fuzz(self):
+ # a few rounds of random data testing
+ for j in range(10, 30):
+ for i in range(1, j - 2):
+ d = np.arange(j)
+ np.random.shuffle(d)
+ d = d % np.random.randint(2, 30)
+ idx = np.random.randint(d.size)
+ kth = [0, idx, i, i + 1]
+ tgt = np.sort(d)[kth]
+ assert_array_equal(np.partition(d, kth)[kth], tgt,
+ err_msg="data: %r\n kth: %r" % (d, kth))
def test_flatten(self):
x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)