diff options
| -rw-r--r-- | eventlet/semaphore.py | 184 | ||||
| -rw-r--r-- | tests/semaphore_test.py | 46 |
2 files changed, 212 insertions, 18 deletions
diff --git a/eventlet/semaphore.py b/eventlet/semaphore.py index 771bc35..21b1e61 100644 --- a/eventlet/semaphore.py +++ b/eventlet/semaphore.py @@ -1,5 +1,10 @@ from eventlet import greenthread from eventlet import hubs +from eventlet import patcher + + +threading = patcher.original('threading') + class Semaphore(object): """An unbounded semaphore. @@ -24,15 +29,20 @@ class Semaphore(object): if value < 0: raise ValueError("Semaphore must be initialized with a positive " "number, got %s" % value) - self._waiters = set() + self.total_waiters = 0 + self.lock = threading.Lock() + self.cond = threading.Condition(self.lock) + self.gt_waiters = {} + self.thr_holders = {} + def __repr__(self): params = (self.__class__.__name__, hex(id(self)), - self.counter, len(self._waiters)) + self.counter, self.total_waiters) return '<%s at %s c=%s _w[%s]>' % params def __str__(self): - params = (self.__class__.__name__, self.counter, len(self._waiters)) + params = (self.__class__.__name__, self.counter, self.total_waiters) return '<%s c=%s _w[%s]>' % params def locked(self): @@ -44,6 +54,120 @@ class Semaphore(object): :class:`~eventlet.semaphore.CappedSemaphore`.""" return False + def _add_greenthread_waiter_for_thread(self): + """Add the current greenthread to list of waiters for + the current Thread. + + Call with self.lock locked. + """ + gt = greenthread.getcurrent() + thr = threading.current_thread() + wt = self.gt_waiters.setdefault(thr.ident, list()) + wt.append(gt) + + def _del_greenthread_waiter_for_thread(self): + """Delete the current greenthread from list of waiters + for the current Thread. + + Call with self.lock locked. + """ + gt = greenthread.getcurrent() + thr = threading.current_thread() + try: + self.gt_waiters[thr.ident].remove(gt) + if not self.gt_waiters[thr.ident]: + # Keep dict from growing. + del self.gt_waiters[thr.ident] + except (KeyError, ValueError): + # Was removed already. + pass + + def _get_greenthread_waiter_for_thread(self): + """Get a waiting greenthread from the current Thread. If + none are waiting, return None. + + Call with self.lock locked. + """ + thr = threading.current_thread() + if thr.ident not in self.gt_waiters: + return None + waiters = self.gt_waiters[thr.ident] + # Must be at least one entry, since our key exists in + # the dict. + waiter = waiters.pop(0) + if not waiters: + # Keep dict from growing. + del self.gt_waiters[thr.ident] + return waiter + + def _thread_has_greenthread_waiter(self): + """Does this thread have a greenthread waiting?""" + thr = threading.current_thread() + # We remove the key from the dict if we get to 0, + # so this works. + return thr.ident in self.gt_waiters + + def _thread_has_hold(self): + """Has the current Thread acquired the Semaphore already?""" + thr = threading.current_thread() + # We remove the key from the dict if we get to 0, + # so this works. + return thr.ident in self.thr_holders + + def _thread_add_holder(self): + """Increment the number of holds the current Thread has + on the Semaphore. + """ + thr = threading.current_thread() + self.thr_holders.setdefault(thr.ident, 0) + self.thr_holders[thr.ident] += 1 + + def _thread_del_holder(self): + """Decrement the number of holds the current Thread has + on the Semaphore. + """ + thr = threading.current_thread() + self.thr_holders[thr.ident] -= 1 + if not self.thr_holders[thr.ident]: + # Keep dict from growing. + del self.thr_holders[thr.ident] + + def _acquire(self): + """Block until we can acquire the Semaphore. + + If the current Thread already has acquired the Semaphore, + one of 2 things is true: + + 1) Another greenthread in the current Thread is trying + to acquire it. + 2) The current Thread is trying to acquire it again. + + If #2 is true, it's a buggy application as it's reached a + deadlock condition. So, we assume #1 is true and switch + greenthreads. + + Also switch greenthreads if there are currently no holders. An + acquire() might be coming from another greenthread. + + Otherwise, if the current Thread does NOT have the lock, then another + Thread must hold it. We'll wait (using a Condition) to be + signaled to try again. + + Call with self.lock locked. This call can potentially unlock + and re-lock it. + """ + while self.locked(): + if not self.thr_holders or self._thread_has_hold(): + self._add_greenthread_waiter_for_thread() + self.lock.release() + try: + hubs.get_hub().switch() + finally: + self.lock.acquire() + self._del_greenthread_waiter_for_thread() + else: + self.cond.wait() + def acquire(self, blocking=True): """Acquire a semaphore. @@ -64,14 +188,23 @@ class Semaphore(object): same thing as when called without arguments, and return true.""" if not blocking and self.locked(): return False - if self.counter <= 0: - self._waiters.add(greenthread.getcurrent()) - try: - while self.counter <= 0: - hubs.get_hub().switch() - finally: - self._waiters.discard(greenthread.getcurrent()) - self.counter -= 1 + if self.lock.acquire(blocking) is False: + return False + try: + # Check again while locked. + if self.locked(): + if not blocking: + self.lock.release() + return False + self.total_waiters += 1 + try: + self._acquire() + finally: + self.total_waiters -= 1 + self._thread_add_holder() + self.counter -= 1 + finally: + self.lock.release() return True def __enter__(self): @@ -84,14 +217,28 @@ class Semaphore(object): The *blocking* argument is for consistency with CappedSemaphore and is ignored""" - self.counter += 1 - if self._waiters: + self.lock.acquire() + try: + try: + self._thread_del_holder() + except KeyError: + pass + self.cond.notify() + has_waiter = self._thread_has_greenthread_waiter() + self.counter += 1 + finally: + self.lock.release() + if has_waiter: hubs.get_hub().schedule_call_global(0, self._do_acquire) return True def _do_acquire(self): - if self._waiters and self.counter>0: - waiter = self._waiters.pop() + self.lock.acquire() + try: + waiter = self._get_greenthread_waiter_for_thread() + finally: + self.lock.release() + if waiter: waiter.switch() def __exit__(self, typ, val, tb): @@ -111,7 +258,7 @@ class Semaphore(object): # positive means there are free items # zero means there are no free items but nobody has requested one # negative means there are requests for items, but no items - return self.counter - len(self._waiters) + return self.counter - self.total_waiters class BoundedSemaphore(Semaphore): @@ -132,7 +279,10 @@ class BoundedSemaphore(Semaphore): The *blocking* argument is for consistency with :class:`CappedSemaphore` and is ignored""" - if self.counter >= self.original_counter: + self.lock.acquire() + too_many = self.counter >= self.original_counter + self.lock.release() + if too_many: raise ValueError, "Semaphore released too many times" return super(BoundedSemaphore, self).release(blocking) diff --git a/tests/semaphore_test.py b/tests/semaphore_test.py index 64153ed..dee978a 100644 --- a/tests/semaphore_test.py +++ b/tests/semaphore_test.py @@ -2,6 +2,8 @@ import unittest import eventlet from eventlet import semaphore from tests import LimitedTestCase +from tests import patcher_test + class TestSemaphore(LimitedTestCase): def test_bounded(self): @@ -19,7 +21,7 @@ class TestSemaphore(LimitedTestCase): self.assertEqual(3, sem.balance) gt1.wait() gt2.wait() - + def test_bounded_with_zero_limit(self): sem = semaphore.CappedSemaphore(0, 0) gt = eventlet.spawn(sem.acquire) @@ -27,5 +29,47 @@ class TestSemaphore(LimitedTestCase): gt.wait() +semaphore_tpool_code = """ +from eventlet import greenthread +from eventlet import patcher +patcher.monkey_patch(thread=True) +from eventlet import tpool +import threading + + +lock = threading.Lock() +info = dict(thr=20) + + +def lock_test(): + lock.acquire() + greenthread.sleep(0) + lock.release() + + +def gt_runner(method, *args): + method(*args) + info['thr'] -= 1 + + +for x in range(10): + greenthread.spawn_n(gt_runner, tpool.execute, lock_test) + greenthread.spawn_n(gt_runner, lock_test) + +for x in xrange(20): + greenthread.sleep(0.5) + if not info['thr']: + break +else: + print 'fail' +""" + + +class MonkeyPatchTester(patcher_test.ProcessBase): + def test_semaphore_with_monkey_patched_thread(self): + output, lines = self.run_script(semaphore_tpool_code) + self.assertEqual(lines, ['']) + + if __name__=='__main__': unittest.main() |
