summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2014-06-01 21:09:42 -0700
committerAndy McCurdy <andy@andymccurdy.com>2014-06-01 21:09:42 -0700
commit26c56b9d816c9d1cc1393c04b04b4f1d688f7353 (patch)
treee1249fec5037cb3e78c856a97aeca4e3b66d9f11
parentb36ab87d22dc9b4f0109d17f6f3c7740bb48a7fe (diff)
downloadredis-py-26c56b9d816c9d1cc1393c04b04b4f1d688f7353.tar.gz
add a lock implementation using Lua scripts.
-rwxr-xr-xredis/client.py22
-rw-r--r--redis/lock.py102
-rw-r--r--tests/test_lock.py110
3 files changed, 197 insertions, 37 deletions
diff --git a/redis/client.py b/redis/client.py
index 502549a..2abd675 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -9,7 +9,7 @@ from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys,
itervalues, izip, long, nativestr, unicode)
from redis.connection import (ConnectionPool, UnixDomainSocketConnection,
SSLConnection, Token)
-from redis.lock import Lock
+from redis.lock import Lock, LuaLock
from redis.exceptions import (
ConnectionError,
DataError,
@@ -433,6 +433,7 @@ class StrictRedis(object):
})
connection_pool = ConnectionPool(**kwargs)
self.connection_pool = connection_pool
+ self._use_lua_lock = None
self.response_callbacks = self.__class__.RESPONSE_CALLBACKS.copy()
@@ -476,7 +477,8 @@ class StrictRedis(object):
except WatchError:
continue
- def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None):
+ def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None,
+ lock_class=None):
"""
Return a new Lock object using key ``name`` that mimics
the behavior of threading.Lock.
@@ -492,9 +494,21 @@ class StrictRedis(object):
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
float or integer, both representing the number of seconds to wait.
+
+ ``lock_class`` forces the specified lock implementation.
"""
- return Lock(self, name, timeout=timeout, sleep=sleep,
- blocking_timeout=blocking_timeout)
+ if lock_class is None:
+ if self._use_lua_lock is None:
+ # the first time .lock() is called, determine if we can use
+ # Lua by attempting to register the necessary scripts
+ try:
+ LuaLock.register_scripts(self)
+ self._use_lua_lock = True
+ except ResponseError:
+ self._use_lua_lock = False
+ lock_class = self._use_lua_lock and LuaLock or Lock
+ return lock_class(self, name, timeout=timeout, sleep=sleep,
+ blocking_timeout=blocking_timeout)
def pubsub(self, **kwargs):
"""
diff --git a/redis/lock.py b/redis/lock.py
index 5c2ba2b..8008d22 100644
--- a/redis/lock.py
+++ b/redis/lock.py
@@ -1,7 +1,7 @@
import time as mod_time
import uuid
from redis.exceptions import LockError, WatchError
-from redis._compat import long
+from redis._compat import b
class Lock(object):
@@ -69,7 +69,7 @@ class Lock(object):
wait trying to acquire the lock.
"""
sleep = self.sleep
- token = uuid.uuid1().hex
+ token = b(uuid.uuid1().hex)
if blocking is None:
blocking = self.blocking
if blocking_timeout is None:
@@ -90,12 +90,9 @@ class Lock(object):
def do_acquire(self, token):
if self.redis.setnx(self.name, token):
if self.timeout:
- if isinstance(self.timeout, (int, long)):
- self.redis.expire(self.name, self.timeout)
- else:
- # convert float to milliseconds
- timeout = int(self.timeout * 1000)
- self.redis.pexpire(self.name, timeout)
+ # convert to milliseconds
+ timeout = int(self.timeout * 1000)
+ self.redis.pexpire(self.name, timeout)
return True
return False
@@ -156,3 +153,92 @@ class Lock(object):
# pexpire returns False if the key doesn't exist
raise LockError("Cannot extend a lock that's no longer owned")
return True
+
+
+class LuaLock(Lock):
+ """
+ A lock implementation that uses Lua scripts rather than pipelines
+ and watches.
+ """
+ lua_acquire = None
+ lua_release = None
+ lua_extend = None
+
+ # KEYS[1] - lock name
+ # ARGV[1] - token
+ # ARGV[2] - timeout in milliseconds
+ # return 1 if lock was acquired, otherwise 0
+ LUA_ACQUIRE_SCRIPT = """
+ if redis.call('setnx', KEYS[1], ARGV[1]) == 1 then
+ if ARGV[2] ~= '' then
+ redis.call('pexpire', KEYS[1], ARGV[2])
+ end
+ return 1
+ end
+ return 0
+ """
+
+ # KEYS[1] - lock name
+ # ARGS[1] - token
+ # return 1 if the lock was released, otherwise 0
+ LUA_RELEASE_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ redis.call('del', KEYS[1])
+ return 1
+ """
+
+ # KEYS[1] - lock name
+ # ARGS[1] - token
+ # ARGS[2] - additional milliseconds
+ # return 1 if the locks time was extended, otherwise 0
+ LUA_EXTEND_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ local expiration = redis.call('pttl', KEYS[1])
+ if not expiration then
+ expiration = 0
+ end
+ if expiration < 0 then
+ return 0
+ end
+ redis.call('pexpire', KEYS[1], expiration + ARGV[2])
+ return 1
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(LuaLock, self).__init__(*args, **kwargs)
+ LuaLock.register_scripts(self.redis)
+
+ @classmethod
+ def register_scripts(cls, redis):
+ if cls.lua_acquire is None:
+ cls.lua_acquire = redis.register_script(cls.LUA_ACQUIRE_SCRIPT)
+ if cls.lua_release is None:
+ cls.lua_release = redis.register_script(cls.LUA_RELEASE_SCRIPT)
+ if cls.lua_extend is None:
+ cls.lua_extend = redis.register_script(cls.LUA_EXTEND_SCRIPT)
+
+ def do_acquire(self, token):
+ timeout = self.timeout and int(self.timeout * 1000) or ''
+ return bool(self.lua_acquire(keys=[self.name],
+ args=[token, timeout],
+ client=self.redis))
+
+ def do_release(self):
+ if not bool(self.lua_release(keys=[self.name],
+ args=[self.token],
+ client=self.redis)):
+ raise LockError("Cannot release a lock that's no longer owned")
+
+ def do_extend(self, additional_time):
+ additional_time = int(additional_time * 1000)
+ if not bool(self.lua_extend(keys=[self.name],
+ args=[self.token, additional_time],
+ client=self.redis)):
+ raise LockError("Cannot extend a lock that's no longer owned")
+ return True
diff --git a/tests/test_lock.py b/tests/test_lock.py
index 304dbe1..028f9a6 100644
--- a/tests/test_lock.py
+++ b/tests/test_lock.py
@@ -2,13 +2,19 @@ from __future__ import with_statement
import pytest
import time
-from redis.lock import LockError
+from redis.exceptions import LockError, ResponseError
+from redis.lock import Lock, LuaLock
class TestLock(object):
+ lock_class = Lock
+
+ def get_lock(self, redis, *args, **kwargs):
+ kwargs['lock_class'] = self.lock_class
+ return redis.lock(*args, **kwargs)
def test_lock(self, sr):
- lock = sr.lock('foo')
+ lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
assert sr.get('foo') == lock.token
assert sr.ttl('foo') == -1
@@ -16,8 +22,8 @@ class TestLock(object):
assert sr.get('foo') is None
def test_competing_locks(self, sr):
- lock1 = sr.lock('foo')
- lock2 = sr.lock('foo')
+ lock1 = self.get_lock(sr, 'foo')
+ lock2 = self.get_lock(sr, 'foo')
assert lock1.acquire(blocking=False)
assert not lock2.acquire(blocking=False)
lock1.release()
@@ -26,43 +32,45 @@ class TestLock(object):
lock2.release()
def test_timeout(self, sr):
- lock = sr.lock('foo', timeout=10)
+ lock = self.get_lock(sr, 'foo', timeout=10)
assert lock.acquire(blocking=False)
- assert 0 < sr.ttl('foo') <= 10
+ assert 8 < sr.ttl('foo') <= 10
lock.release()
def test_float_timeout(self, sr):
- lock = sr.lock('foo', timeout=9.5)
+ lock = self.get_lock(sr, 'foo', timeout=9.5)
assert lock.acquire(blocking=False)
- assert 0 < sr.pttl('foo') <= 9500
+ assert 8 < sr.pttl('foo') <= 9500
lock.release()
def test_blocking_timeout(self, sr):
- lock1 = sr.lock('foo')
+ lock1 = self.get_lock(sr, 'foo')
assert lock1.acquire(blocking=False)
- lock2 = sr.lock('foo', blocking_timeout=0.2)
+ lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2)
start = time.time()
assert not lock2.acquire()
assert (time.time() - start) > 0.2
lock1.release()
def test_context_manager(self, sr):
- with sr.lock('foo') as lock:
+ # blocking_timeout prevents a deadlock if the lock can't be acquired
+ # for some reason
+ with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock:
assert sr.get('foo') == lock.token
assert sr.get('foo') is None
def test_high_sleep_raises_error(self, sr):
"If sleep is higher than timeout, it should raise an error"
with pytest.raises(LockError):
- sr.lock('foo', timeout=1, sleep=2)
+ self.get_lock(sr, 'foo', timeout=1, sleep=2)
def test_releasing_unlocked_lock_raises_error(self, sr):
- lock = sr.lock('foo')
+ lock = self.get_lock(sr, 'foo')
with pytest.raises(LockError):
lock.release()
def test_releasing_lock_no_longer_owned_raises_error(self, sr):
- lock = sr.lock('foo')
+ lock = self.get_lock(sr, 'foo')
lock.acquire(blocking=False)
# manually change the token
sr.set('foo', 'a')
@@ -72,36 +80,88 @@ class TestLock(object):
assert lock.token is None
def test_extend_lock(self, sr):
- lock = sr.lock('foo', timeout=10)
- assert lock.acquire(blocking=True)
- assert 0 < sr.pttl('foo') <= 10000
+ lock = self.get_lock(sr, 'foo', timeout=10)
+ assert lock.acquire(blocking=False)
+ assert 8000 < sr.pttl('foo') <= 10000
assert lock.extend(10)
- assert 10000 < sr.pttl('foo') < 20000
+ assert 16000 < sr.pttl('foo') <= 20000
lock.release()
def test_extend_lock_float(self, sr):
- lock = sr.lock('foo', timeout=10.0)
- assert lock.acquire(blocking=True)
- assert 0 < sr.pttl('foo') <= 10000
+ lock = self.get_lock(sr, 'foo', timeout=10.0)
+ assert lock.acquire(blocking=False)
+ assert 8000 < sr.pttl('foo') <= 10000
assert lock.extend(10.0)
- assert 10000 < sr.pttl('foo') < 20000
+ assert 16000 < sr.pttl('foo') <= 20000
lock.release()
def test_extending_unlocked_lock_raises_error(self, sr):
- lock = sr.lock('foo', timeout=10)
+ lock = self.get_lock(sr, 'foo', timeout=10)
with pytest.raises(LockError):
lock.extend(10)
def test_extending_lock_with_no_timeout_raises_error(self, sr):
- lock = sr.lock('foo')
+ lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
with pytest.raises(LockError):
lock.extend(10)
lock.release()
def test_extending_lock_no_longer_owned_raises_error(self, sr):
- lock = sr.lock('foo')
+ lock = self.get_lock(sr, 'foo')
assert lock.acquire(blocking=False)
sr.set('foo', 'a')
with pytest.raises(LockError):
lock.extend(10)
+
+
+class TestLuaLock(TestLock):
+ lock_class = LuaLock
+
+
+class TestLockClassSelection(object):
+ def test_lock_class_argument(self, sr):
+ lock = sr.lock('foo', lock_class=Lock)
+ assert type(lock) == Lock
+ lock = sr.lock('foo', lock_class=LuaLock)
+ assert type(lock) == LuaLock
+
+ def test_cached_lualock_flag(self, sr):
+ try:
+ sr._use_lua_lock = True
+ lock = sr.lock('foo')
+ assert type(lock) == LuaLock
+ finally:
+ sr._use_lua_lock = None
+
+ def test_cached_lock_flag(self, sr):
+ try:
+ sr._use_lua_lock = False
+ lock = sr.lock('foo')
+ assert type(lock) == Lock
+ finally:
+ sr._use_lua_lock = None
+
+ def test_lua_compatible_server(self, sr, monkeypatch):
+ @classmethod
+ def mock_register(cls, redis):
+ return
+ monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
+ try:
+ lock = sr.lock('foo')
+ assert type(lock) == LuaLock
+ assert sr._use_lua_lock is True
+ finally:
+ sr._use_lua_lock = None
+
+ def test_lua_unavailable(self, sr, monkeypatch):
+ @classmethod
+ def mock_register(cls, redis):
+ raise ResponseError()
+ monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
+ try:
+ lock = sr.lock('foo')
+ assert type(lock) == Lock
+ assert sr._use_lua_lock is False
+ finally:
+ sr._use_lua_lock = None