summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--eventlet/semaphore.py184
-rw-r--r--tests/semaphore_test.py46
2 files changed, 212 insertions, 18 deletions
diff --git a/eventlet/semaphore.py b/eventlet/semaphore.py
index 771bc35..21b1e61 100644
--- a/eventlet/semaphore.py
+++ b/eventlet/semaphore.py
@@ -1,5 +1,10 @@
from eventlet import greenthread
from eventlet import hubs
+from eventlet import patcher
+
+
+threading = patcher.original('threading')
+
class Semaphore(object):
"""An unbounded semaphore.
@@ -24,15 +29,20 @@ class Semaphore(object):
if value < 0:
raise ValueError("Semaphore must be initialized with a positive "
"number, got %s" % value)
- self._waiters = set()
+ self.total_waiters = 0
+ self.lock = threading.Lock()
+ self.cond = threading.Condition(self.lock)
+ self.gt_waiters = {}
+ self.thr_holders = {}
+
def __repr__(self):
params = (self.__class__.__name__, hex(id(self)),
- self.counter, len(self._waiters))
+ self.counter, self.total_waiters)
return '<%s at %s c=%s _w[%s]>' % params
def __str__(self):
- params = (self.__class__.__name__, self.counter, len(self._waiters))
+ params = (self.__class__.__name__, self.counter, self.total_waiters)
return '<%s c=%s _w[%s]>' % params
def locked(self):
@@ -44,6 +54,120 @@ class Semaphore(object):
:class:`~eventlet.semaphore.CappedSemaphore`."""
return False
+ def _add_greenthread_waiter_for_thread(self):
+ """Add the current greenthread to list of waiters for
+ the current Thread.
+
+ Call with self.lock locked.
+ """
+ gt = greenthread.getcurrent()
+ thr = threading.current_thread()
+ wt = self.gt_waiters.setdefault(thr.ident, list())
+ wt.append(gt)
+
+ def _del_greenthread_waiter_for_thread(self):
+ """Delete the current greenthread from list of waiters
+ for the current Thread.
+
+ Call with self.lock locked.
+ """
+ gt = greenthread.getcurrent()
+ thr = threading.current_thread()
+ try:
+ self.gt_waiters[thr.ident].remove(gt)
+ if not self.gt_waiters[thr.ident]:
+ # Keep dict from growing.
+ del self.gt_waiters[thr.ident]
+ except (KeyError, ValueError):
+ # Was removed already.
+ pass
+
+ def _get_greenthread_waiter_for_thread(self):
+ """Get a waiting greenthread from the current Thread. If
+ none are waiting, return None.
+
+ Call with self.lock locked.
+ """
+ thr = threading.current_thread()
+ if thr.ident not in self.gt_waiters:
+ return None
+ waiters = self.gt_waiters[thr.ident]
+ # Must be at least one entry, since our key exists in
+ # the dict.
+ waiter = waiters.pop(0)
+ if not waiters:
+ # Keep dict from growing.
+ del self.gt_waiters[thr.ident]
+ return waiter
+
+ def _thread_has_greenthread_waiter(self):
+ """Does this thread have a greenthread waiting?"""
+ thr = threading.current_thread()
+ # We remove the key from the dict if we get to 0,
+ # so this works.
+ return thr.ident in self.gt_waiters
+
+ def _thread_has_hold(self):
+ """Has the current Thread acquired the Semaphore already?"""
+ thr = threading.current_thread()
+ # We remove the key from the dict if we get to 0,
+ # so this works.
+ return thr.ident in self.thr_holders
+
+ def _thread_add_holder(self):
+ """Increment the number of holds the current Thread has
+ on the Semaphore.
+ """
+ thr = threading.current_thread()
+ self.thr_holders.setdefault(thr.ident, 0)
+ self.thr_holders[thr.ident] += 1
+
+ def _thread_del_holder(self):
+ """Decrement the number of holds the current Thread has
+ on the Semaphore.
+ """
+ thr = threading.current_thread()
+ self.thr_holders[thr.ident] -= 1
+ if not self.thr_holders[thr.ident]:
+ # Keep dict from growing.
+ del self.thr_holders[thr.ident]
+
+ def _acquire(self):
+ """Block until we can acquire the Semaphore.
+
+ If the current Thread already has acquired the Semaphore,
+ one of 2 things is true:
+
+ 1) Another greenthread in the current Thread is trying
+ to acquire it.
+ 2) The current Thread is trying to acquire it again.
+
+ If #2 is true, it's a buggy application as it's reached a
+ deadlock condition. So, we assume #1 is true and switch
+ greenthreads.
+
+ Also switch greenthreads if there are currently no holders. An
+ acquire() might be coming from another greenthread.
+
+ Otherwise, if the current Thread does NOT have the lock, then another
+ Thread must hold it. We'll wait (using a Condition) to be
+ signaled to try again.
+
+ Call with self.lock locked. This call can potentially unlock
+ and re-lock it.
+ """
+ while self.locked():
+ if not self.thr_holders or self._thread_has_hold():
+ self._add_greenthread_waiter_for_thread()
+ self.lock.release()
+ try:
+ hubs.get_hub().switch()
+ finally:
+ self.lock.acquire()
+ self._del_greenthread_waiter_for_thread()
+ else:
+ self.cond.wait()
+
def acquire(self, blocking=True):
"""Acquire a semaphore.
@@ -64,14 +188,23 @@ class Semaphore(object):
same thing as when called without arguments, and return true."""
if not blocking and self.locked():
return False
- if self.counter <= 0:
- self._waiters.add(greenthread.getcurrent())
- try:
- while self.counter <= 0:
- hubs.get_hub().switch()
- finally:
- self._waiters.discard(greenthread.getcurrent())
- self.counter -= 1
+ if self.lock.acquire(blocking) is False:
+ return False
+ try:
+ # Check again while locked.
+ if self.locked():
+ if not blocking:
+ self.lock.release()
+ return False
+ self.total_waiters += 1
+ try:
+ self._acquire()
+ finally:
+ self.total_waiters -= 1
+ self._thread_add_holder()
+ self.counter -= 1
+ finally:
+ self.lock.release()
return True
def __enter__(self):
@@ -84,14 +217,28 @@ class Semaphore(object):
The *blocking* argument is for consistency with CappedSemaphore and is
ignored"""
- self.counter += 1
- if self._waiters:
+ self.lock.acquire()
+ try:
+ try:
+ self._thread_del_holder()
+ except KeyError:
+ pass
+ self.cond.notify()
+ has_waiter = self._thread_has_greenthread_waiter()
+ self.counter += 1
+ finally:
+ self.lock.release()
+ if has_waiter:
hubs.get_hub().schedule_call_global(0, self._do_acquire)
return True
def _do_acquire(self):
- if self._waiters and self.counter>0:
- waiter = self._waiters.pop()
+ self.lock.acquire()
+ try:
+ waiter = self._get_greenthread_waiter_for_thread()
+ finally:
+ self.lock.release()
+ if waiter:
waiter.switch()
def __exit__(self, typ, val, tb):
@@ -111,7 +258,7 @@ class Semaphore(object):
# positive means there are free items
# zero means there are no free items but nobody has requested one
# negative means there are requests for items, but no items
- return self.counter - len(self._waiters)
+ return self.counter - self.total_waiters
class BoundedSemaphore(Semaphore):
@@ -132,7 +279,10 @@ class BoundedSemaphore(Semaphore):
The *blocking* argument is for consistency with :class:`CappedSemaphore`
and is ignored"""
- if self.counter >= self.original_counter:
+ self.lock.acquire()
+ too_many = self.counter >= self.original_counter
+ self.lock.release()
+ if too_many:
raise ValueError, "Semaphore released too many times"
return super(BoundedSemaphore, self).release(blocking)
diff --git a/tests/semaphore_test.py b/tests/semaphore_test.py
index 64153ed..dee978a 100644
--- a/tests/semaphore_test.py
+++ b/tests/semaphore_test.py
@@ -2,6 +2,8 @@ import unittest
import eventlet
from eventlet import semaphore
from tests import LimitedTestCase
+from tests import patcher_test
+
class TestSemaphore(LimitedTestCase):
def test_bounded(self):
@@ -19,7 +21,7 @@ class TestSemaphore(LimitedTestCase):
self.assertEqual(3, sem.balance)
gt1.wait()
gt2.wait()
-
+
def test_bounded_with_zero_limit(self):
sem = semaphore.CappedSemaphore(0, 0)
gt = eventlet.spawn(sem.acquire)
@@ -27,5 +29,47 @@ class TestSemaphore(LimitedTestCase):
gt.wait()
+semaphore_tpool_code = """
+from eventlet import greenthread
+from eventlet import patcher
+patcher.monkey_patch(thread=True)
+from eventlet import tpool
+import threading
+
+
+lock = threading.Lock()
+info = dict(thr=20)
+
+
+def lock_test():
+ lock.acquire()
+ greenthread.sleep(0)
+ lock.release()
+
+
+def gt_runner(method, *args):
+ method(*args)
+ info['thr'] -= 1
+
+
+for x in range(10):
+ greenthread.spawn_n(gt_runner, tpool.execute, lock_test)
+ greenthread.spawn_n(gt_runner, lock_test)
+
+for x in xrange(20):
+ greenthread.sleep(0.5)
+ if not info['thr']:
+ break
+else:
+ print 'fail'
+"""
+
+
+class MonkeyPatchTester(patcher_test.ProcessBase):
+ def test_semaphore_with_monkey_patched_thread(self):
+ output, lines = self.run_script(semaphore_tpool_code)
+ self.assertEqual(lines, [''])
+
+
if __name__=='__main__':
unittest.main()