summaryrefslogtreecommitdiff
path: root/taskflow/tests/unit/test_utils_lock_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'taskflow/tests/unit/test_utils_lock_utils.py')
-rw-r--r--taskflow/tests/unit/test_utils_lock_utils.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/taskflow/tests/unit/test_utils_lock_utils.py b/taskflow/tests/unit/test_utils_lock_utils.py
index 6846b13..80037b8 100644
--- a/taskflow/tests/unit/test_utils_lock_utils.py
+++ b/taskflow/tests/unit/test_utils_lock_utils.py
@@ -17,6 +17,7 @@
# under the License.
import collections
+import threading
import time
from concurrent import futures
@@ -89,6 +90,57 @@ class ReadWriteLockTest(test.TestCase):
self.assertRaises(RuntimeError, blow_up)
self.assertFalse(lock.owner)
+ def test_double_reader_abort(self):
+ lock = lock_utils.ReaderWriterLock()
+ activated = collections.deque()
+
+ def double_bad_reader():
+ with lock.read_lock():
+ with lock.read_lock():
+ raise RuntimeError("Broken")
+
+ def happy_writer():
+ with lock.write_lock():
+ activated.append(lock.owner)
+
+ with futures.ThreadPoolExecutor(max_workers=20) as e:
+ for i in range(0, 20):
+ if i % 2 == 0:
+ e.submit(double_bad_reader)
+ else:
+ e.submit(happy_writer)
+
+ self.assertEqual(10, len([a for a in activated if a == 'w']))
+
+ def test_double_reader_writer(self):
+ lock = lock_utils.ReaderWriterLock()
+ activated = collections.deque()
+ active = threading.Event()
+
+ def double_reader():
+ with lock.read_lock():
+ active.set()
+ while lock.pending_writers == 0:
+ time.sleep(0.001)
+ with lock.read_lock():
+ activated.append(lock.owner)
+
+ def happy_writer():
+ with lock.write_lock():
+ activated.append(lock.owner)
+
+ reader = threading.Thread(target=double_reader)
+ reader.start()
+ active.wait()
+
+ writer = threading.Thread(target=happy_writer)
+ writer.start()
+
+ reader.join()
+ writer.join()
+ self.assertEqual(2, len(activated))
+ self.assertEqual(['r', 'w'], list(activated))
+
def test_reader_chaotic(self):
lock = lock_utils.ReaderWriterLock()
activated = collections.deque()