summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2014-06-01 08:12:47 -0700
committerAndy McCurdy <andy@andymccurdy.com>2014-06-01 08:14:27 -0700
commitb36ab87d22dc9b4f0109d17f6f3c7740bb48a7fe (patch)
tree1f53a914aa23a569da2b39f93e7bdf40f02ed2a2
parentbe6e501d046c9fb7ef86a47ec3b276a22997e900 (diff)
downloadredis-py-b36ab87d22dc9b4f0109d17f6f3c7740bb48a7fe.tar.gz
updated Lock class:
* now uses unique string tokens to claim lock ownership * added extend() method to extend the timeout on an already acquired lock
-rwxr-xr-xredis/client.py10
-rw-r--r--redis/exceptions.py5
-rw-r--r--redis/lock.py158
-rw-r--r--tests/test_lock.py124
4 files changed, 255 insertions, 42 deletions
diff --git a/redis/client.py b/redis/client.py
index efce1d8..502549a 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -476,7 +476,7 @@ class StrictRedis(object):
except WatchError:
continue
- def lock(self, name, timeout=None, sleep=0.1):
+ def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None):
"""
Return a new Lock object using key ``name`` that mimics
the behavior of threading.Lock.
@@ -487,8 +487,14 @@ class StrictRedis(object):
``sleep`` indicates the amount of time to sleep per loop iteration
when the lock is in blocking mode and another client is currently
holding the lock.
+
+ ``blocking_timeout`` indicates the maximum amount of time in seconds to
+ spend trying to acquire the lock. A value of ``None`` indicates
+ continue trying forever. ``blocking_timeout`` can be specified as a
+ float or integer, both representing the number of seconds to wait.
"""
- return Lock(self, name, timeout=timeout, sleep=sleep)
+ return Lock(self, name, timeout=timeout, sleep=sleep,
+ blocking_timeout=blocking_timeout)
def pubsub(self, **kwargs):
"""
diff --git a/redis/exceptions.py b/redis/exceptions.py
index 6ab5a6e..a8518c7 100644
--- a/redis/exceptions.py
+++ b/redis/exceptions.py
@@ -64,5 +64,8 @@ class ReadOnlyError(ResponseError):
pass
-class LockError(RedisError):
+class LockError(RedisError, ValueError):
+ "Errors acquiring or releasing a lock"
+ # NOTE: For backwards compatability, this class derives from ValueError.
+ # This was originally chosen to behave like threading.Lock.
pass
diff --git a/redis/lock.py b/redis/lock.py
new file mode 100644
index 0000000..5c2ba2b
--- /dev/null
+++ b/redis/lock.py
@@ -0,0 +1,158 @@
+import time as mod_time
+import uuid
+from redis.exceptions import LockError, WatchError
+from redis._compat import long
+
+
+class Lock(object):
+ """
+ A shared, distributed Lock. Using Redis for locking allows the Lock
+ to be shared across processes and/or machines.
+
+ It's left to the user to resolve deadlock issues and make sure
+ multiple clients play nicely together.
+ """
+ def __init__(self, redis, name, timeout=None, sleep=0.1,
+ blocking=True, blocking_timeout=None):
+ """
+ Create a new Lock instance named ``name`` using the Redis client
+ supplied by ``redis``.
+
+ ``timeout`` indicates a maximum life for the lock.
+ By default, it will remain locked until release() is called.
+ ``timeout`` can be specified as a float or integer, both representing
+ the number of seconds to wait.
+
+ ``sleep`` indicates the amount of time to sleep per loop iteration
+ when the lock is in blocking mode and another client is currently
+ holding the lock.
+
+ ``blocking`` indicates whether calling ``acquire`` should block until
+ the lock has been acquired or to fail immediately, causing ``acquire``
+ to return False and the lock not being acquired. Defaults to True.
+ Note this value can be overridden by passing a ``blocking``
+ argument to ``acquire``.
+
+ ``blocking_timeout`` indicates the maximum amount of time in seconds to
+ spend trying to acquire the lock. A value of ``None`` indicates
+ continue trying forever. ``blocking_timeout`` can be specified as a
+ float or integer, both representing the number of seconds to wait.
+ """
+ self.redis = redis
+ self.name = name
+ self.timeout = timeout
+ self.sleep = sleep
+ self.blocking = blocking
+ self.blocking_timeout = blocking_timeout
+ self.token = None
+ if self.timeout and self.sleep > self.timeout:
+ raise LockError("'sleep' must be less than 'timeout'")
+
+ def __enter__(self):
+ # force blocking, as otherwise the user would have to check whether
+ # the lock was actually acquired or not.
+ self.acquire(blocking=True)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.release()
+
+ def acquire(self, blocking=None, blocking_timeout=None):
+ """
+ Use Redis to hold a shared, distributed lock named ``name``.
+ Returns True once the lock is acquired.
+
+ If ``blocking`` is False, always return immediately. If the lock
+ was acquired, return True, otherwise return False.
+
+ ``blocking_timeout`` specifies the maximum number of seconds to
+ wait trying to acquire the lock.
+ """
+ sleep = self.sleep
+ token = uuid.uuid1().hex
+ if blocking is None:
+ blocking = self.blocking
+ if blocking_timeout is None:
+ blocking_timeout = self.blocking_timeout
+ stop_trying_at = None
+ if self.blocking_timeout is not None:
+ stop_trying_at = mod_time.time() + self.blocking_timeout
+ while 1:
+ if self.do_acquire(token):
+ self.token = token
+ return True
+ if not blocking:
+ return False
+ if stop_trying_at is not None and mod_time.time() > stop_trying_at:
+ return False
+ mod_time.sleep(sleep)
+
+ def do_acquire(self, token):
+ if self.redis.setnx(self.name, token):
+ if self.timeout:
+ if isinstance(self.timeout, (int, long)):
+ self.redis.expire(self.name, self.timeout)
+ else:
+ # convert float to milliseconds
+ timeout = int(self.timeout * 1000)
+ self.redis.pexpire(self.name, timeout)
+ return True
+ return False
+
+ def release(self):
+ "Releases the already acquired lock"
+ if self.token is None:
+ raise LockError("Cannot release an unlocked lock")
+ try:
+ self.do_release()
+ finally:
+ self.token = None
+
+ def do_release(self):
+ name = self.name
+ token = self.token
+
+ def execute_release(pipe):
+ lock_value = pipe.get(name)
+ if lock_value != token:
+ raise LockError("Cannot release a lock that's no longer owned")
+ pipe.delete(name)
+
+ self.redis.transaction(execute_release, name)
+
+ def extend(self, additional_time):
+ """
+ Adds more time to an already acquired lock.
+
+ ``additional_time`` can be specified as an integer or a float, both
+ representing the number of seconds to add.
+ """
+ if self.token is None:
+ raise LockError("Cannot extend an unlocked lock")
+ if self.timeout is None:
+ raise LockError("Cannot extend a lock with no timeout")
+ return self.do_extend(additional_time)
+
+ def do_extend(self, additional_time):
+ pipe = self.redis.pipeline()
+ pipe.watch(self.name)
+ lock_value = pipe.get(self.name)
+ if lock_value != self.token:
+ raise LockError("Cannot extend a lock that's no longer owned")
+ expiration = pipe.pttl(self.name)
+ if expiration is None or expiration < 0:
+ # Redis evicted the lock key between the previous get() and now
+ # we'll handle this when we call pexpire()
+ expiration = 0
+ pipe.multi()
+ pipe.pexpire(self.name, expiration + int(additional_time * 1000))
+
+ try:
+ response = pipe.execute()
+ except WatchError:
+ # someone else acquired the lock
+ raise LockError("Cannot extend a lock that's no longer owned")
+ if not response[0]:
+ # pexpire returns False if the key doesn't exist
+ raise LockError("Cannot extend a lock that's no longer owned")
+ return True
diff --git a/tests/test_lock.py b/tests/test_lock.py
index a31f5f1..304dbe1 100644
--- a/tests/test_lock.py
+++ b/tests/test_lock.py
@@ -2,60 +2,106 @@ from __future__ import with_statement
import pytest
import time
-from redis.lock import Lock, LockError
+from redis.lock import LockError
class TestLock(object):
- def test_lock(self, r):
- lock = r.lock('foo')
- assert lock.acquire()
- assert r['foo'] == str(Lock.LOCK_FOREVER).encode()
+ def test_lock(self, sr):
+ lock = sr.lock('foo')
+ assert lock.acquire(blocking=False)
+ assert sr.get('foo') == lock.token
+ assert sr.ttl('foo') == -1
lock.release()
- assert r.get('foo') is None
+ assert sr.get('foo') is None
- def test_competing_locks(self, r):
- lock1 = r.lock('foo')
- lock2 = r.lock('foo')
- assert lock1.acquire()
+ def test_competing_locks(self, sr):
+ lock1 = sr.lock('foo')
+ lock2 = sr.lock('foo')
+ assert lock1.acquire(blocking=False)
assert not lock2.acquire(blocking=False)
lock1.release()
- assert lock2.acquire()
+ assert lock2.acquire(blocking=False)
assert not lock1.acquire(blocking=False)
lock2.release()
- def test_timeouts(self, r):
- lock1 = r.lock('foo', timeout=1)
- lock2 = r.lock('foo')
- assert lock1.acquire()
- now = time.time()
- assert now < lock1.acquired_until < now + 1
- assert lock1.acquired_until == float(r['foo'])
- assert not lock2.acquire(blocking=False)
- time.sleep(2) # need to wait up to 2 seconds for lock to timeout
- assert lock2.acquire(blocking=False)
- lock2.release()
+ def test_timeout(self, sr):
+ lock = sr.lock('foo', timeout=10)
+ assert lock.acquire(blocking=False)
+ assert 0 < sr.ttl('foo') <= 10
+ lock.release()
- def test_non_blocking(self, r):
- lock1 = r.lock('foo')
+ def test_float_timeout(self, sr):
+ lock = sr.lock('foo', timeout=9.5)
+ assert lock.acquire(blocking=False)
+ assert 0 < sr.pttl('foo') <= 9500
+ lock.release()
+
+ def test_blocking_timeout(self, sr):
+ lock1 = sr.lock('foo')
assert lock1.acquire(blocking=False)
- assert lock1.acquired_until
+ lock2 = sr.lock('foo', blocking_timeout=0.2)
+ start = time.time()
+ assert not lock2.acquire()
+ assert (time.time() - start) > 0.2
lock1.release()
- assert lock1.acquired_until is None
-
- def test_context_manager(self, r):
- with r.lock('foo'):
- assert r['foo'] == str(Lock.LOCK_FOREVER).encode()
- assert r.get('foo') is None
- def test_float_timeout(self, r):
- lock1 = r.lock('foo', timeout=1.5)
- lock2 = r.lock('foo', timeout=1.5)
- assert lock1.acquire()
- assert not lock2.acquire(blocking=False)
- lock1.release()
+ def test_context_manager(self, sr):
+ with sr.lock('foo') as lock:
+ assert sr.get('foo') == lock.token
+ assert sr.get('foo') is None
- def test_high_sleep_raises_error(self, r):
+ def test_high_sleep_raises_error(self, sr):
"If sleep is higher than timeout, it should raise an error"
with pytest.raises(LockError):
- r.lock('foo', timeout=1, sleep=2)
+ sr.lock('foo', timeout=1, sleep=2)
+
+ def test_releasing_unlocked_lock_raises_error(self, sr):
+ lock = sr.lock('foo')
+ with pytest.raises(LockError):
+ lock.release()
+
+ def test_releasing_lock_no_longer_owned_raises_error(self, sr):
+ lock = sr.lock('foo')
+ lock.acquire(blocking=False)
+ # manually change the token
+ sr.set('foo', 'a')
+ with pytest.raises(LockError):
+ lock.release()
+ # even though we errored, the token is still cleared
+ assert lock.token is None
+
+ def test_extend_lock(self, sr):
+ lock = sr.lock('foo', timeout=10)
+ assert lock.acquire(blocking=True)
+ assert 0 < sr.pttl('foo') <= 10000
+ assert lock.extend(10)
+ assert 10000 < sr.pttl('foo') < 20000
+ lock.release()
+
+ def test_extend_lock_float(self, sr):
+ lock = sr.lock('foo', timeout=10.0)
+ assert lock.acquire(blocking=True)
+ assert 0 < sr.pttl('foo') <= 10000
+ assert lock.extend(10.0)
+ assert 10000 < sr.pttl('foo') < 20000
+ lock.release()
+
+ def test_extending_unlocked_lock_raises_error(self, sr):
+ lock = sr.lock('foo', timeout=10)
+ with pytest.raises(LockError):
+ lock.extend(10)
+
+ def test_extending_lock_with_no_timeout_raises_error(self, sr):
+ lock = sr.lock('foo')
+ assert lock.acquire(blocking=False)
+ with pytest.raises(LockError):
+ lock.extend(10)
+ lock.release()
+
+ def test_extending_lock_no_longer_owned_raises_error(self, sr):
+ lock = sr.lock('foo')
+ assert lock.acquire(blocking=False)
+ sr.set('foo', 'a')
+ with pytest.raises(LockError):
+ lock.extend(10)