diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2014-06-01 08:12:47 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2014-06-01 08:14:27 -0700 |
commit | b36ab87d22dc9b4f0109d17f6f3c7740bb48a7fe (patch) | |
tree | 1f53a914aa23a569da2b39f93e7bdf40f02ed2a2 | |
parent | be6e501d046c9fb7ef86a47ec3b276a22997e900 (diff) | |
download | redis-py-b36ab87d22dc9b4f0109d17f6f3c7740bb48a7fe.tar.gz |
updated Lock class:
* now uses unique string tokens to claim lock ownership
* added extend() method to extend the timeout on an already acquired lock
-rwxr-xr-x | redis/client.py | 10 | ||||
-rw-r--r-- | redis/exceptions.py | 5 | ||||
-rw-r--r-- | redis/lock.py | 158 | ||||
-rw-r--r-- | tests/test_lock.py | 124 |
4 files changed, 255 insertions, 42 deletions
diff --git a/redis/client.py b/redis/client.py index efce1d8..502549a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -476,7 +476,7 @@ class StrictRedis(object): except WatchError: continue - def lock(self, name, timeout=None, sleep=0.1): + def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None): """ Return a new Lock object using key ``name`` that mimics the behavior of threading.Lock. @@ -487,8 +487,14 @@ class StrictRedis(object): ``sleep`` indicates the amount of time to sleep per loop iteration when the lock is in blocking mode and another client is currently holding the lock. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. """ - return Lock(self, name, timeout=timeout, sleep=sleep) + return Lock(self, name, timeout=timeout, sleep=sleep, + blocking_timeout=blocking_timeout) def pubsub(self, **kwargs): """ diff --git a/redis/exceptions.py b/redis/exceptions.py index 6ab5a6e..a8518c7 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -64,5 +64,8 @@ class ReadOnlyError(ResponseError): pass -class LockError(RedisError): +class LockError(RedisError, ValueError): + "Errors acquiring or releasing a lock" + # NOTE: For backwards compatability, this class derives from ValueError. + # This was originally chosen to behave like threading.Lock. pass diff --git a/redis/lock.py b/redis/lock.py new file mode 100644 index 0000000..5c2ba2b --- /dev/null +++ b/redis/lock.py @@ -0,0 +1,158 @@ +import time as mod_time +import uuid +from redis.exceptions import LockError, WatchError +from redis._compat import long + + +class Lock(object): + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together. + """ + def __init__(self, redis, name, timeout=None, sleep=0.1, + blocking=True, blocking_timeout=None): + """ + Create a new Lock instance named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + ``timeout`` can be specified as a float or integer, both representing + the number of seconds to wait. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + """ + self.redis = redis + self.name = name + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.token = None + if self.timeout and self.sleep > self.timeout: + raise LockError("'sleep' must be less than 'timeout'") + + def __enter__(self): + # force blocking, as otherwise the user would have to check whether + # the lock was actually acquired or not. + self.acquire(blocking=True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.release() + + def acquire(self, blocking=None, blocking_timeout=None): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + + ``blocking_timeout`` specifies the maximum number of seconds to + wait trying to acquire the lock. + """ + sleep = self.sleep + token = uuid.uuid1().hex + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if self.blocking_timeout is not None: + stop_trying_at = mod_time.time() + self.blocking_timeout + while 1: + if self.do_acquire(token): + self.token = token + return True + if not blocking: + return False + if stop_trying_at is not None and mod_time.time() > stop_trying_at: + return False + mod_time.sleep(sleep) + + def do_acquire(self, token): + if self.redis.setnx(self.name, token): + if self.timeout: + if isinstance(self.timeout, (int, long)): + self.redis.expire(self.name, self.timeout) + else: + # convert float to milliseconds + timeout = int(self.timeout * 1000) + self.redis.pexpire(self.name, timeout) + return True + return False + + def release(self): + "Releases the already acquired lock" + if self.token is None: + raise LockError("Cannot release an unlocked lock") + try: + self.do_release() + finally: + self.token = None + + def do_release(self): + name = self.name + token = self.token + + def execute_release(pipe): + lock_value = pipe.get(name) + if lock_value != token: + raise LockError("Cannot release a lock that's no longer owned") + pipe.delete(name) + + self.redis.transaction(execute_release, name) + + def extend(self, additional_time): + """ + Adds more time to an already acquired lock. + + ``additional_time`` can be specified as an integer or a float, both + representing the number of seconds to add. + """ + if self.token is None: + raise LockError("Cannot extend an unlocked lock") + if self.timeout is None: + raise LockError("Cannot extend a lock with no timeout") + return self.do_extend(additional_time) + + def do_extend(self, additional_time): + pipe = self.redis.pipeline() + pipe.watch(self.name) + lock_value = pipe.get(self.name) + if lock_value != self.token: + raise LockError("Cannot extend a lock that's no longer owned") + expiration = pipe.pttl(self.name) + if expiration is None or expiration < 0: + # Redis evicted the lock key between the previous get() and now + # we'll handle this when we call pexpire() + expiration = 0 + pipe.multi() + pipe.pexpire(self.name, expiration + int(additional_time * 1000)) + + try: + response = pipe.execute() + except WatchError: + # someone else acquired the lock + raise LockError("Cannot extend a lock that's no longer owned") + if not response[0]: + # pexpire returns False if the key doesn't exist + raise LockError("Cannot extend a lock that's no longer owned") + return True diff --git a/tests/test_lock.py b/tests/test_lock.py index a31f5f1..304dbe1 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -2,60 +2,106 @@ from __future__ import with_statement import pytest import time -from redis.lock import Lock, LockError +from redis.lock import LockError class TestLock(object): - def test_lock(self, r): - lock = r.lock('foo') - assert lock.acquire() - assert r['foo'] == str(Lock.LOCK_FOREVER).encode() + def test_lock(self, sr): + lock = sr.lock('foo') + assert lock.acquire(blocking=False) + assert sr.get('foo') == lock.token + assert sr.ttl('foo') == -1 lock.release() - assert r.get('foo') is None + assert sr.get('foo') is None - def test_competing_locks(self, r): - lock1 = r.lock('foo') - lock2 = r.lock('foo') - assert lock1.acquire() + def test_competing_locks(self, sr): + lock1 = sr.lock('foo') + lock2 = sr.lock('foo') + assert lock1.acquire(blocking=False) assert not lock2.acquire(blocking=False) lock1.release() - assert lock2.acquire() + assert lock2.acquire(blocking=False) assert not lock1.acquire(blocking=False) lock2.release() - def test_timeouts(self, r): - lock1 = r.lock('foo', timeout=1) - lock2 = r.lock('foo') - assert lock1.acquire() - now = time.time() - assert now < lock1.acquired_until < now + 1 - assert lock1.acquired_until == float(r['foo']) - assert not lock2.acquire(blocking=False) - time.sleep(2) # need to wait up to 2 seconds for lock to timeout - assert lock2.acquire(blocking=False) - lock2.release() + def test_timeout(self, sr): + lock = sr.lock('foo', timeout=10) + assert lock.acquire(blocking=False) + assert 0 < sr.ttl('foo') <= 10 + lock.release() - def test_non_blocking(self, r): - lock1 = r.lock('foo') + def test_float_timeout(self, sr): + lock = sr.lock('foo', timeout=9.5) + assert lock.acquire(blocking=False) + assert 0 < sr.pttl('foo') <= 9500 + lock.release() + + def test_blocking_timeout(self, sr): + lock1 = sr.lock('foo') assert lock1.acquire(blocking=False) - assert lock1.acquired_until + lock2 = sr.lock('foo', blocking_timeout=0.2) + start = time.time() + assert not lock2.acquire() + assert (time.time() - start) > 0.2 lock1.release() - assert lock1.acquired_until is None - - def test_context_manager(self, r): - with r.lock('foo'): - assert r['foo'] == str(Lock.LOCK_FOREVER).encode() - assert r.get('foo') is None - def test_float_timeout(self, r): - lock1 = r.lock('foo', timeout=1.5) - lock2 = r.lock('foo', timeout=1.5) - assert lock1.acquire() - assert not lock2.acquire(blocking=False) - lock1.release() + def test_context_manager(self, sr): + with sr.lock('foo') as lock: + assert sr.get('foo') == lock.token + assert sr.get('foo') is None - def test_high_sleep_raises_error(self, r): + def test_high_sleep_raises_error(self, sr): "If sleep is higher than timeout, it should raise an error" with pytest.raises(LockError): - r.lock('foo', timeout=1, sleep=2) + sr.lock('foo', timeout=1, sleep=2) + + def test_releasing_unlocked_lock_raises_error(self, sr): + lock = sr.lock('foo') + with pytest.raises(LockError): + lock.release() + + def test_releasing_lock_no_longer_owned_raises_error(self, sr): + lock = sr.lock('foo') + lock.acquire(blocking=False) + # manually change the token + sr.set('foo', 'a') + with pytest.raises(LockError): + lock.release() + # even though we errored, the token is still cleared + assert lock.token is None + + def test_extend_lock(self, sr): + lock = sr.lock('foo', timeout=10) + assert lock.acquire(blocking=True) + assert 0 < sr.pttl('foo') <= 10000 + assert lock.extend(10) + assert 10000 < sr.pttl('foo') < 20000 + lock.release() + + def test_extend_lock_float(self, sr): + lock = sr.lock('foo', timeout=10.0) + assert lock.acquire(blocking=True) + assert 0 < sr.pttl('foo') <= 10000 + assert lock.extend(10.0) + assert 10000 < sr.pttl('foo') < 20000 + lock.release() + + def test_extending_unlocked_lock_raises_error(self, sr): + lock = sr.lock('foo', timeout=10) + with pytest.raises(LockError): + lock.extend(10) + + def test_extending_lock_with_no_timeout_raises_error(self, sr): + lock = sr.lock('foo') + assert lock.acquire(blocking=False) + with pytest.raises(LockError): + lock.extend(10) + lock.release() + + def test_extending_lock_no_longer_owned_raises_error(self, sr): + lock = sr.lock('foo') + assert lock.acquire(blocking=False) + sr.set('foo', 'a') + with pytest.raises(LockError): + lock.extend(10) |