summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-03-02 09:45:59 -0800
committerSebastian Berg <sebastian@sipsolutions.net>2021-02-17 23:06:47 -0600
commite01d89471b0c4a69f7d1e868af31af7220f90318 (patch)
tree652796235fc2e500f133dfccf8c344c99e027a71 /numpy
parenta1ec3385b8390046029591ba51c94c17cd31f5ae (diff)
downloadnumpy-e01d89471b0c4a69f7d1e868af31af7220f90318.tar.gz
TST: Add test for nonzero and copyswapn (through advanced indexing)
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/tests/test_indexing.py24
-rw-r--r--numpy/core/tests/test_numeric.py21
2 files changed, 45 insertions, 0 deletions
diff --git a/numpy/core/tests/test_indexing.py b/numpy/core/tests/test_indexing.py
index 73dbc429c..57b1f3827 100644
--- a/numpy/core/tests/test_indexing.py
+++ b/numpy/core/tests/test_indexing.py
@@ -563,6 +563,30 @@ class TestIndexing:
with pytest.raises(IndexError):
arr[(index,) * num] = 1.
+ def test_structured_advanced_indexing(self):
+ # Test that copyswap(n) used by integer array indexing is threadsafe
+ # for structured datatypes, see gh-15387. This test can behave randomly.
+ from concurrent.futures import ThreadPoolExecutor
+
+ # Create a deeply nested dtype to make a failure more likely:
+ dt = np.dtype([("", "f8")])
+ dt = np.dtype([("", dt)] * 2)
+ dt = np.dtype([("", dt)] * 2)
+ # The array should be large enough to likely run into threading issues
+ arr = np.random.uniform(size=(6000, 8)).view(dt)[:, 0]
+
+ rng = np.random.default_rng()
+ def func(arr):
+ indx = rng.integers(0, len(arr), size=6000, dtype=np.intp)
+ arr[indx]
+
+ tpe = ThreadPoolExecutor(max_workers=8)
+ futures = [tpe.submit(func, arr) for _ in range(10)]
+ for f in futures:
+ f.result()
+
+ assert arr.dtype is dt
+
class TestFieldIndexing:
def test_scalar_return_type(self):
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 06511822e..8d3cec708 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1536,6 +1536,27 @@ class TestNonzero:
a = np.array([[ThrowsAfter(15)]]*10)
assert_raises(ValueError, np.nonzero, a)
+ def test_structured_threadsafety(self):
+ # Nonzero (and some other functions) should be threadsafe for
+ # structured datatypes, see gh-15387. This test can behave randomly.
+ from concurrent.futures import ThreadPoolExecutor
+
+ # Create a deeply nested dtype to make a failure more likely:
+ dt = np.dtype([("", "f8")])
+ dt = np.dtype([("", dt)])
+ dt = np.dtype([("", dt)] * 2)
+ # The array should be large enough to likely run into threading issues
+ arr = np.random.uniform(size=(5000, 4)).view(dt)[:, 0]
+ def func(arr):
+ arr.nonzero()
+
+ tpe = ThreadPoolExecutor(max_workers=8)
+ futures = [tpe.submit(func, arr) for _ in range(10)]
+ for f in futures:
+ f.result()
+
+ assert arr.dtype is dt
+
class TestIndex:
def test_boolean(self):