summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnas <anas.el.amraoui@live.com>2022-06-01 14:59:02 +0300
committerGitHub <noreply@github.com>2022-06-01 14:59:02 +0300
commitfa0be7671de6be85f859cbb57a31531b2482c9e1 (patch)
tree5470f36a4094e413f88480da35a3980ba73999bf
parent05fc203f68c24fbd54c7b338b4610fa62972c326 (diff)
downloadredis-py-fa0be7671de6be85f859cbb57a31531b2482c9e1.tar.gz
Made sync lock consistent and added types to it (#2137)
* Made sync lock consistent and added types to it * Made linters happy * Fixed cluster client lock signature Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
-rwxr-xr-xredis/client.py8
-rw-r--r--redis/cluster.py8
-rw-r--r--redis/lock.py68
-rw-r--r--redis/typing.py1
-rw-r--r--tests/test_lock.py20
5 files changed, 80 insertions, 25 deletions
diff --git a/redis/client.py b/redis/client.py
index 58668ee..fcc2758 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1098,6 +1098,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands):
name,
timeout=None,
sleep=0.1,
+ blocking=True,
blocking_timeout=None,
lock_class=None,
thread_local=True,
@@ -1113,6 +1114,12 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands):
when the lock is in blocking mode and another client is currently
holding the lock.
+ ``blocking`` indicates whether calling ``acquire`` should block until
+ the lock has been acquired or to fail immediately, causing ``acquire``
+ to return False and the lock not being acquired. Defaults to True.
+ Note this value can be overridden by passing a ``blocking``
+ argument to ``acquire``.
+
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
@@ -1155,6 +1162,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands):
name,
timeout=timeout,
sleep=sleep,
+ blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
diff --git a/redis/cluster.py b/redis/cluster.py
index a88792b..8e4c654 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -766,6 +766,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
name,
timeout=None,
sleep=0.1,
+ blocking=True,
blocking_timeout=None,
lock_class=None,
thread_local=True,
@@ -781,6 +782,12 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
when the lock is in blocking mode and another client is currently
holding the lock.
+ ``blocking`` indicates whether calling ``acquire`` should block until
+ the lock has been acquired or to fail immediately, causing ``acquire``
+ to return False and the lock not being acquired. Defaults to True.
+ Note this value can be overridden by passing a ``blocking``
+ argument to ``acquire``.
+
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
@@ -823,6 +830,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
name,
timeout=timeout,
sleep=sleep,
+ blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
diff --git a/redis/lock.py b/redis/lock.py
index 74e769b..c509f7d 100644
--- a/redis/lock.py
+++ b/redis/lock.py
@@ -1,9 +1,11 @@
import threading
import time as mod_time
import uuid
-from types import SimpleNamespace
+from types import SimpleNamespace, TracebackType
+from typing import Optional, Type
from redis.exceptions import LockError, LockNotOwnedError
+from redis.typing import Number
class Lock:
@@ -74,12 +76,13 @@ class Lock:
def __init__(
self,
redis,
- name,
- timeout=None,
- sleep=0.1,
- blocking=True,
- blocking_timeout=None,
- thread_local=True,
+ name: str,
+ *,
+ timeout: Optional[Number] = None,
+ sleep: Number = 0.1,
+ blocking: bool = True,
+ blocking_timeout: Optional[Number] = None,
+ thread_local: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
@@ -142,7 +145,7 @@ class Lock:
self.local.token = None
self.register_scripts()
- def register_scripts(self):
+ def register_scripts(self) -> None:
cls = self.__class__
client = self.redis
if cls.lua_release is None:
@@ -152,15 +155,27 @@ class Lock:
if cls.lua_reacquire is None:
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
- def __enter__(self):
+ def __enter__(self) -> "Lock":
if self.acquire():
return self
raise LockError("Unable to acquire lock within the time specified")
- def __exit__(self, exc_type, exc_value, traceback):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
self.release()
- def acquire(self, blocking=None, blocking_timeout=None, token=None):
+ def acquire(
+ self,
+ *,
+ sleep: Optional[Number] = None,
+ blocking: Optional[bool] = None,
+ blocking_timeout: Optional[Number] = None,
+ token: Optional[str] = None,
+ ):
"""
Use Redis to hold a shared, distributed lock named ``name``.
Returns True once the lock is acquired.
@@ -176,7 +191,8 @@ class Lock:
object with the default encoding. If a token isn't specified, a UUID
will be generated.
"""
- sleep = self.sleep
+ if sleep is None:
+ sleep = self.sleep
if token is None:
token = uuid.uuid1().hex.encode()
else:
@@ -200,7 +216,7 @@ class Lock:
return False
mod_time.sleep(sleep)
- def do_acquire(self, token):
+ def do_acquire(self, token: str) -> bool:
if self.timeout:
# convert to milliseconds
timeout = int(self.timeout * 1000)
@@ -210,13 +226,13 @@ class Lock:
return True
return False
- def locked(self):
+ def locked(self) -> bool:
"""
Returns True if this key is locked by any process, otherwise False.
"""
return self.redis.get(self.name) is not None
- def owned(self):
+ def owned(self) -> bool:
"""
Returns True if this key is locked by this lock, otherwise False.
"""
@@ -228,21 +244,23 @@ class Lock:
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token
- def release(self):
- "Releases the already acquired lock"
+ def release(self) -> None:
+ """
+ Releases the already acquired lock
+ """
expected_token = self.local.token
if expected_token is None:
raise LockError("Cannot release an unlocked lock")
self.local.token = None
self.do_release(expected_token)
- def do_release(self, expected_token):
+ def do_release(self, expected_token: str) -> None:
if not bool(
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
):
raise LockNotOwnedError("Cannot release a lock" " that's no longer owned")
- def extend(self, additional_time, replace_ttl=False):
+ def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
"""
Adds more time to an already acquired lock.
@@ -259,19 +277,19 @@ class Lock:
raise LockError("Cannot extend a lock with no timeout")
return self.do_extend(additional_time, replace_ttl)
- def do_extend(self, additional_time, replace_ttl):
+ def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
additional_time = int(additional_time * 1000)
if not bool(
self.lua_extend(
keys=[self.name],
- args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
+ args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
client=self.redis,
)
):
- raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned")
+ raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
return True
- def reacquire(self):
+ def reacquire(self) -> bool:
"""
Resets a TTL of an already acquired lock back to a timeout value.
"""
@@ -281,12 +299,12 @@ class Lock:
raise LockError("Cannot reacquire a lock with no timeout")
return self.do_reacquire()
- def do_reacquire(self):
+ def do_reacquire(self) -> bool:
timeout = int(self.timeout * 1000)
if not bool(
self.lua_reacquire(
keys=[self.name], args=[self.local.token, timeout], client=self.redis
)
):
- raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned")
+ raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
return True
diff --git a/redis/typing.py b/redis/typing.py
index 6748612..b572b0c 100644
--- a/redis/typing.py
+++ b/redis/typing.py
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
from redis.connection import ConnectionPool, Encoder
+Number = Union[int, float]
EncodedT = Union[bytes, memoryview]
DecodedT = Union[str, int, float]
EncodableT = Union[EncodedT, DecodedT]
diff --git a/tests/test_lock.py b/tests/test_lock.py
index 0a63f1e..10ad7e1 100644
--- a/tests/test_lock.py
+++ b/tests/test_lock.py
@@ -116,6 +116,16 @@ class TestLock:
assert r.get("foo") == lock.local.token
assert r.get("foo") is None
+ def test_context_manager_blocking_timeout(self, r):
+ with self.get_lock(r, "foo", blocking=False):
+ bt = 0.4
+ sleep = 0.05
+ lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
+ start = time.monotonic()
+ assert not lock2.acquire()
+ # The elapsed duration should be less than the total blocking_timeout
+ assert bt > (time.monotonic() - start) > bt - sleep
+
def test_context_manager_raises_when_locked_not_acquired(self, r):
r.set("foo", "bar")
with pytest.raises(LockError):
@@ -221,6 +231,16 @@ class TestLock:
with pytest.raises(LockNotOwnedError):
lock.reacquire()
+ def test_context_manager_reacquiring_lock_with_no_timeout_raises_error(self, r):
+ with self.get_lock(r, "foo", timeout=None, blocking=False) as lock:
+ with pytest.raises(LockError):
+ lock.reacquire()
+
+ def test_context_manager_reacquiring_lock_no_longer_owned_raises_error(self, r):
+ with pytest.raises(LockNotOwnedError):
+ with self.get_lock(r, "foo", timeout=10, blocking=False):
+ r.set("foo", "a")
+
class TestLockClassSelection:
def test_lock_class_argument(self, r):