diff options
author | Anas <anas.el.amraoui@live.com> | 2022-06-01 14:59:02 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-01 14:59:02 +0300 |
commit | fa0be7671de6be85f859cbb57a31531b2482c9e1 (patch) | |
tree | 5470f36a4094e413f88480da35a3980ba73999bf | |
parent | 05fc203f68c24fbd54c7b338b4610fa62972c326 (diff) | |
download | redis-py-fa0be7671de6be85f859cbb57a31531b2482c9e1.tar.gz |
Made sync lock consistent and added types to it (#2137)
* Made sync lock consistent and added types to it
* Made linters happy
* Fixed cluster client lock signature
Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
-rwxr-xr-x | redis/client.py | 8 | ||||
-rw-r--r-- | redis/cluster.py | 8 | ||||
-rw-r--r-- | redis/lock.py | 68 | ||||
-rw-r--r-- | redis/typing.py | 1 | ||||
-rw-r--r-- | tests/test_lock.py | 20 |
5 files changed, 80 insertions, 25 deletions
diff --git a/redis/client.py b/redis/client.py index 58668ee..fcc2758 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1098,6 +1098,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): name, timeout=None, sleep=0.1, + blocking=True, blocking_timeout=None, lock_class=None, thread_local=True, @@ -1113,6 +1114,12 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): when the lock is in blocking mode and another client is currently holding the lock. + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + ``blocking_timeout`` indicates the maximum amount of time in seconds to spend trying to acquire the lock. A value of ``None`` indicates continue trying forever. ``blocking_timeout`` can be specified as a @@ -1155,6 +1162,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): name, timeout=timeout, sleep=sleep, + blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, ) diff --git a/redis/cluster.py b/redis/cluster.py index a88792b..8e4c654 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -766,6 +766,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): name, timeout=None, sleep=0.1, + blocking=True, blocking_timeout=None, lock_class=None, thread_local=True, @@ -781,6 +782,12 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): when the lock is in blocking mode and another client is currently holding the lock. + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + ``blocking_timeout`` indicates the maximum amount of time in seconds to spend trying to acquire the lock. A value of ``None`` indicates continue trying forever. ``blocking_timeout`` can be specified as a @@ -823,6 +830,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): name, timeout=timeout, sleep=sleep, + blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, ) diff --git a/redis/lock.py b/redis/lock.py index 74e769b..c509f7d 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -1,9 +1,11 @@ import threading import time as mod_time import uuid -from types import SimpleNamespace +from types import SimpleNamespace, TracebackType +from typing import Optional, Type from redis.exceptions import LockError, LockNotOwnedError +from redis.typing import Number class Lock: @@ -74,12 +76,13 @@ class Lock: def __init__( self, redis, - name, - timeout=None, - sleep=0.1, - blocking=True, - blocking_timeout=None, - thread_local=True, + name: str, + *, + timeout: Optional[Number] = None, + sleep: Number = 0.1, + blocking: bool = True, + blocking_timeout: Optional[Number] = None, + thread_local: bool = True, ): """ Create a new Lock instance named ``name`` using the Redis client @@ -142,7 +145,7 @@ class Lock: self.local.token = None self.register_scripts() - def register_scripts(self): + def register_scripts(self) -> None: cls = self.__class__ client = self.redis if cls.lua_release is None: @@ -152,15 +155,27 @@ class Lock: if cls.lua_reacquire is None: cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) - def __enter__(self): + def __enter__(self) -> "Lock": if self.acquire(): return self raise LockError("Unable to acquire lock within the time specified") - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: self.release() - def acquire(self, blocking=None, blocking_timeout=None, token=None): + def acquire( + self, + *, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + token: Optional[str] = None, + ): """ Use Redis to hold a shared, distributed lock named ``name``. Returns True once the lock is acquired. @@ -176,7 +191,8 @@ class Lock: object with the default encoding. If a token isn't specified, a UUID will be generated. """ - sleep = self.sleep + if sleep is None: + sleep = self.sleep if token is None: token = uuid.uuid1().hex.encode() else: @@ -200,7 +216,7 @@ class Lock: return False mod_time.sleep(sleep) - def do_acquire(self, token): + def do_acquire(self, token: str) -> bool: if self.timeout: # convert to milliseconds timeout = int(self.timeout * 1000) @@ -210,13 +226,13 @@ class Lock: return True return False - def locked(self): + def locked(self) -> bool: """ Returns True if this key is locked by any process, otherwise False. """ return self.redis.get(self.name) is not None - def owned(self): + def owned(self) -> bool: """ Returns True if this key is locked by this lock, otherwise False. """ @@ -228,21 +244,23 @@ class Lock: stored_token = encoder.encode(stored_token) return self.local.token is not None and stored_token == self.local.token - def release(self): - "Releases the already acquired lock" + def release(self) -> None: + """ + Releases the already acquired lock + """ expected_token = self.local.token if expected_token is None: raise LockError("Cannot release an unlocked lock") self.local.token = None self.do_release(expected_token) - def do_release(self, expected_token): + def do_release(self, expected_token: str) -> None: if not bool( self.lua_release(keys=[self.name], args=[expected_token], client=self.redis) ): raise LockNotOwnedError("Cannot release a lock" " that's no longer owned") - def extend(self, additional_time, replace_ttl=False): + def extend(self, additional_time: int, replace_ttl: bool = False) -> bool: """ Adds more time to an already acquired lock. @@ -259,19 +277,19 @@ class Lock: raise LockError("Cannot extend a lock with no timeout") return self.do_extend(additional_time, replace_ttl) - def do_extend(self, additional_time, replace_ttl): + def do_extend(self, additional_time: int, replace_ttl: bool) -> bool: additional_time = int(additional_time * 1000) if not bool( self.lua_extend( keys=[self.name], - args=[self.local.token, additional_time, replace_ttl and "1" or "0"], + args=[self.local.token, additional_time, "1" if replace_ttl else "0"], client=self.redis, ) ): - raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned") + raise LockNotOwnedError("Cannot extend a lock that's no longer owned") return True - def reacquire(self): + def reacquire(self) -> bool: """ Resets a TTL of an already acquired lock back to a timeout value. """ @@ -281,12 +299,12 @@ class Lock: raise LockError("Cannot reacquire a lock with no timeout") return self.do_reacquire() - def do_reacquire(self): + def do_reacquire(self) -> bool: timeout = int(self.timeout * 1000) if not bool( self.lua_reacquire( keys=[self.name], args=[self.local.token, timeout], client=self.redis ) ): - raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned") + raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned") return True diff --git a/redis/typing.py b/redis/typing.py index 6748612..b572b0c 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from redis.connection import ConnectionPool, Encoder +Number = Union[int, float] EncodedT = Union[bytes, memoryview] DecodedT = Union[str, int, float] EncodableT = Union[EncodedT, DecodedT] diff --git a/tests/test_lock.py b/tests/test_lock.py index 0a63f1e..10ad7e1 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -116,6 +116,16 @@ class TestLock: assert r.get("foo") == lock.local.token assert r.get("foo") is None + def test_context_manager_blocking_timeout(self, r): + with self.get_lock(r, "foo", blocking=False): + bt = 0.4 + sleep = 0.05 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = time.monotonic() + assert not lock2.acquire() + # The elapsed duration should be less than the total blocking_timeout + assert bt > (time.monotonic() - start) > bt - sleep + def test_context_manager_raises_when_locked_not_acquired(self, r): r.set("foo", "bar") with pytest.raises(LockError): @@ -221,6 +231,16 @@ class TestLock: with pytest.raises(LockNotOwnedError): lock.reacquire() + def test_context_manager_reacquiring_lock_with_no_timeout_raises_error(self, r): + with self.get_lock(r, "foo", timeout=None, blocking=False) as lock: + with pytest.raises(LockError): + lock.reacquire() + + def test_context_manager_reacquiring_lock_no_longer_owned_raises_error(self, r): + with pytest.raises(LockNotOwnedError): + with self.get_lock(r, "foo", timeout=10, blocking=False): + r.set("foo", "a") + class TestLockClassSelection: def test_lock_class_argument(self, r): |