diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2014-06-01 21:09:42 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2014-06-01 21:09:42 -0700 |
commit | 26c56b9d816c9d1cc1393c04b04b4f1d688f7353 (patch) | |
tree | e1249fec5037cb3e78c856a97aeca4e3b66d9f11 | |
parent | b36ab87d22dc9b4f0109d17f6f3c7740bb48a7fe (diff) | |
download | redis-py-26c56b9d816c9d1cc1393c04b04b4f1d688f7353.tar.gz |
add a lock implementation using Lua scripts.
-rwxr-xr-x | redis/client.py | 22 | ||||
-rw-r--r-- | redis/lock.py | 102 | ||||
-rw-r--r-- | tests/test_lock.py | 110 |
3 files changed, 197 insertions, 37 deletions
diff --git a/redis/client.py b/redis/client.py index 502549a..2abd675 100755 --- a/redis/client.py +++ b/redis/client.py @@ -9,7 +9,7 @@ from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, itervalues, izip, long, nativestr, unicode) from redis.connection import (ConnectionPool, UnixDomainSocketConnection, SSLConnection, Token) -from redis.lock import Lock +from redis.lock import Lock, LuaLock from redis.exceptions import ( ConnectionError, DataError, @@ -433,6 +433,7 @@ class StrictRedis(object): }) connection_pool = ConnectionPool(**kwargs) self.connection_pool = connection_pool + self._use_lua_lock = None self.response_callbacks = self.__class__.RESPONSE_CALLBACKS.copy() @@ -476,7 +477,8 @@ class StrictRedis(object): except WatchError: continue - def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None): + def lock(self, name, timeout=None, sleep=0.1, blocking_timeout=None, + lock_class=None): """ Return a new Lock object using key ``name`` that mimics the behavior of threading.Lock. @@ -492,9 +494,21 @@ class StrictRedis(object): spend trying to acquire the lock. A value of ``None`` indicates continue trying forever. ``blocking_timeout`` can be specified as a float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. """ - return Lock(self, name, timeout=timeout, sleep=sleep, - blocking_timeout=blocking_timeout) + if lock_class is None: + if self._use_lua_lock is None: + # the first time .lock() is called, determine if we can use + # Lua by attempting to register the necessary scripts + try: + LuaLock.register_scripts(self) + self._use_lua_lock = True + except ResponseError: + self._use_lua_lock = False + lock_class = self._use_lua_lock and LuaLock or Lock + return lock_class(self, name, timeout=timeout, sleep=sleep, + blocking_timeout=blocking_timeout) def pubsub(self, **kwargs): """ diff --git a/redis/lock.py b/redis/lock.py index 5c2ba2b..8008d22 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -1,7 +1,7 @@ import time as mod_time import uuid from redis.exceptions import LockError, WatchError -from redis._compat import long +from redis._compat import b class Lock(object): @@ -69,7 +69,7 @@ class Lock(object): wait trying to acquire the lock. """ sleep = self.sleep - token = uuid.uuid1().hex + token = b(uuid.uuid1().hex) if blocking is None: blocking = self.blocking if blocking_timeout is None: @@ -90,12 +90,9 @@ class Lock(object): def do_acquire(self, token): if self.redis.setnx(self.name, token): if self.timeout: - if isinstance(self.timeout, (int, long)): - self.redis.expire(self.name, self.timeout) - else: - # convert float to milliseconds - timeout = int(self.timeout * 1000) - self.redis.pexpire(self.name, timeout) + # convert to milliseconds + timeout = int(self.timeout * 1000) + self.redis.pexpire(self.name, timeout) return True return False @@ -156,3 +153,92 @@ class Lock(object): # pexpire returns False if the key doesn't exist raise LockError("Cannot extend a lock that's no longer owned") return True + + +class LuaLock(Lock): + """ + A lock implementation that uses Lua scripts rather than pipelines + and watches. + """ + lua_acquire = None + lua_release = None + lua_extend = None + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - timeout in milliseconds + # return 1 if lock was acquired, otherwise 0 + LUA_ACQUIRE_SCRIPT = """ + if redis.call('setnx', KEYS[1], ARGV[1]) == 1 then + if ARGV[2] ~= '' then + redis.call('pexpire', KEYS[1], ARGV[2]) + end + return 1 + end + return 0 + """ + + # KEYS[1] - lock name + # ARGS[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - lock name + # ARGS[1] - token + # ARGS[2] - additional milliseconds + # return 1 if the locks time was extended, otherwise 0 + LUA_EXTEND_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + redis.call('pexpire', KEYS[1], expiration + ARGV[2]) + return 1 + """ + + def __init__(self, *args, **kwargs): + super(LuaLock, self).__init__(*args, **kwargs) + LuaLock.register_scripts(self.redis) + + @classmethod + def register_scripts(cls, redis): + if cls.lua_acquire is None: + cls.lua_acquire = redis.register_script(cls.LUA_ACQUIRE_SCRIPT) + if cls.lua_release is None: + cls.lua_release = redis.register_script(cls.LUA_RELEASE_SCRIPT) + if cls.lua_extend is None: + cls.lua_extend = redis.register_script(cls.LUA_EXTEND_SCRIPT) + + def do_acquire(self, token): + timeout = self.timeout and int(self.timeout * 1000) or '' + return bool(self.lua_acquire(keys=[self.name], + args=[token, timeout], + client=self.redis)) + + def do_release(self): + if not bool(self.lua_release(keys=[self.name], + args=[self.token], + client=self.redis)): + raise LockError("Cannot release a lock that's no longer owned") + + def do_extend(self, additional_time): + additional_time = int(additional_time * 1000) + if not bool(self.lua_extend(keys=[self.name], + args=[self.token, additional_time], + client=self.redis)): + raise LockError("Cannot extend a lock that's no longer owned") + return True diff --git a/tests/test_lock.py b/tests/test_lock.py index 304dbe1..028f9a6 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -2,13 +2,19 @@ from __future__ import with_statement import pytest import time -from redis.lock import LockError +from redis.exceptions import LockError, ResponseError +from redis.lock import Lock, LuaLock class TestLock(object): + lock_class = Lock + + def get_lock(self, redis, *args, **kwargs): + kwargs['lock_class'] = self.lock_class + return redis.lock(*args, **kwargs) def test_lock(self, sr): - lock = sr.lock('foo') + lock = self.get_lock(sr, 'foo') assert lock.acquire(blocking=False) assert sr.get('foo') == lock.token assert sr.ttl('foo') == -1 @@ -16,8 +22,8 @@ class TestLock(object): assert sr.get('foo') is None def test_competing_locks(self, sr): - lock1 = sr.lock('foo') - lock2 = sr.lock('foo') + lock1 = self.get_lock(sr, 'foo') + lock2 = self.get_lock(sr, 'foo') assert lock1.acquire(blocking=False) assert not lock2.acquire(blocking=False) lock1.release() @@ -26,43 +32,45 @@ class TestLock(object): lock2.release() def test_timeout(self, sr): - lock = sr.lock('foo', timeout=10) + lock = self.get_lock(sr, 'foo', timeout=10) assert lock.acquire(blocking=False) - assert 0 < sr.ttl('foo') <= 10 + assert 8 < sr.ttl('foo') <= 10 lock.release() def test_float_timeout(self, sr): - lock = sr.lock('foo', timeout=9.5) + lock = self.get_lock(sr, 'foo', timeout=9.5) assert lock.acquire(blocking=False) - assert 0 < sr.pttl('foo') <= 9500 + assert 8 < sr.pttl('foo') <= 9500 lock.release() def test_blocking_timeout(self, sr): - lock1 = sr.lock('foo') + lock1 = self.get_lock(sr, 'foo') assert lock1.acquire(blocking=False) - lock2 = sr.lock('foo', blocking_timeout=0.2) + lock2 = self.get_lock(sr, 'foo', blocking_timeout=0.2) start = time.time() assert not lock2.acquire() assert (time.time() - start) > 0.2 lock1.release() def test_context_manager(self, sr): - with sr.lock('foo') as lock: + # blocking_timeout prevents a deadlock if the lock can't be acquired + # for some reason + with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock: assert sr.get('foo') == lock.token assert sr.get('foo') is None def test_high_sleep_raises_error(self, sr): "If sleep is higher than timeout, it should raise an error" with pytest.raises(LockError): - sr.lock('foo', timeout=1, sleep=2) + self.get_lock(sr, 'foo', timeout=1, sleep=2) def test_releasing_unlocked_lock_raises_error(self, sr): - lock = sr.lock('foo') + lock = self.get_lock(sr, 'foo') with pytest.raises(LockError): lock.release() def test_releasing_lock_no_longer_owned_raises_error(self, sr): - lock = sr.lock('foo') + lock = self.get_lock(sr, 'foo') lock.acquire(blocking=False) # manually change the token sr.set('foo', 'a') @@ -72,36 +80,88 @@ class TestLock(object): assert lock.token is None def test_extend_lock(self, sr): - lock = sr.lock('foo', timeout=10) - assert lock.acquire(blocking=True) - assert 0 < sr.pttl('foo') <= 10000 + lock = self.get_lock(sr, 'foo', timeout=10) + assert lock.acquire(blocking=False) + assert 8000 < sr.pttl('foo') <= 10000 assert lock.extend(10) - assert 10000 < sr.pttl('foo') < 20000 + assert 16000 < sr.pttl('foo') <= 20000 lock.release() def test_extend_lock_float(self, sr): - lock = sr.lock('foo', timeout=10.0) - assert lock.acquire(blocking=True) - assert 0 < sr.pttl('foo') <= 10000 + lock = self.get_lock(sr, 'foo', timeout=10.0) + assert lock.acquire(blocking=False) + assert 8000 < sr.pttl('foo') <= 10000 assert lock.extend(10.0) - assert 10000 < sr.pttl('foo') < 20000 + assert 16000 < sr.pttl('foo') <= 20000 lock.release() def test_extending_unlocked_lock_raises_error(self, sr): - lock = sr.lock('foo', timeout=10) + lock = self.get_lock(sr, 'foo', timeout=10) with pytest.raises(LockError): lock.extend(10) def test_extending_lock_with_no_timeout_raises_error(self, sr): - lock = sr.lock('foo') + lock = self.get_lock(sr, 'foo') assert lock.acquire(blocking=False) with pytest.raises(LockError): lock.extend(10) lock.release() def test_extending_lock_no_longer_owned_raises_error(self, sr): - lock = sr.lock('foo') + lock = self.get_lock(sr, 'foo') assert lock.acquire(blocking=False) sr.set('foo', 'a') with pytest.raises(LockError): lock.extend(10) + + +class TestLuaLock(TestLock): + lock_class = LuaLock + + +class TestLockClassSelection(object): + def test_lock_class_argument(self, sr): + lock = sr.lock('foo', lock_class=Lock) + assert type(lock) == Lock + lock = sr.lock('foo', lock_class=LuaLock) + assert type(lock) == LuaLock + + def test_cached_lualock_flag(self, sr): + try: + sr._use_lua_lock = True + lock = sr.lock('foo') + assert type(lock) == LuaLock + finally: + sr._use_lua_lock = None + + def test_cached_lock_flag(self, sr): + try: + sr._use_lua_lock = False + lock = sr.lock('foo') + assert type(lock) == Lock + finally: + sr._use_lua_lock = None + + def test_lua_compatible_server(self, sr, monkeypatch): + @classmethod + def mock_register(cls, redis): + return + monkeypatch.setattr(LuaLock, 'register_scripts', mock_register) + try: + lock = sr.lock('foo') + assert type(lock) == LuaLock + assert sr._use_lua_lock is True + finally: + sr._use_lua_lock = None + + def test_lua_unavailable(self, sr, monkeypatch): + @classmethod + def mock_register(cls, redis): + raise ResponseError() + monkeypatch.setattr(LuaLock, 'register_scripts', mock_register) + try: + lock = sr.lock('foo') + assert type(lock) == Lock + assert sr._use_lua_lock is False + finally: + sr._use_lua_lock = None |