From 772079fabd7453edf3788d0c31b9caf21ff5deca Mon Sep 17 00:00:00 2001 From: Milhan Date: Wed, 9 Nov 2022 21:22:05 +0900 Subject: Enable AsyncIO cluster mode lock (#2446) Co-authored-by: Chayim --- CHANGES | 1 + redis/asyncio/cluster.py | 67 +++++++++++++++++++++++++++++++++++++++++ redis/asyncio/lock.py | 16 +++++++--- redis/commands/core.py | 12 ++++++-- tests/test_asyncio/test_lock.py | 1 - 5 files changed, 90 insertions(+), 7 deletions(-) diff --git a/CHANGES b/CHANGES index 80423cc..7bdfacf 100644 --- a/CHANGES +++ b/CHANGES @@ -24,6 +24,7 @@ * ClusterPipeline Doesn't Handle ConnectionError for Dead Hosts (#2225) * Remove compatibility code for old versions of Hiredis, drop Packaging dependency * The `deprecated` library is no longer a dependency + * Enable Lock for asyncio cluster mode * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 8d34b9a..8abb072 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -24,6 +24,7 @@ from redis.asyncio.connection import ( SSLConnection, parse_url, ) +from redis.asyncio.lock import Lock from redis.asyncio.parser import CommandsParser from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( @@ -764,6 +765,72 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand return ClusterPipeline(self) + def lock( + self, + name: KeyT, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking_timeout: Optional[float] = None, + lock_class: Optional[Type[Lock]] = None, + thread_local: bool = True, + ) -> Lock: + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. Note that as + of redis-py 3.0, the only lock class we implement is ``Lock`` (which is + a Lua-based lock). So, it's unlikely you'll need this parameter, unless + you have created your own custom lock class. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + class ClusterNode: """ diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py index 8ede59b..7f45c8b 100644 --- a/redis/asyncio/lock.py +++ b/redis/asyncio/lock.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Awaitable, Optional, Union from redis.exceptions import LockError, LockNotOwnedError if TYPE_CHECKING: - from redis.asyncio import Redis + from redis.asyncio import Redis, RedisCluster class Lock: @@ -77,7 +77,7 @@ class Lock: def __init__( self, - redis: "Redis", + redis: Union["Redis", "RedisCluster"], name: Union[str, bytes, memoryview], timeout: Optional[float] = None, sleep: float = 0.1, @@ -189,7 +189,11 @@ class Lock: if token is None: token = uuid.uuid1().hex.encode() else: - encoder = self.redis.connection_pool.get_encoder() + try: + encoder = self.redis.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = self.redis.get_encoder() token = encoder.encode(token) if blocking is None: blocking = self.blocking @@ -233,7 +237,11 @@ class Lock: # need to always compare bytes to bytes # TODO: this can be simplified when the context manager is finished if stored_token and not isinstance(stored_token, bytes): - encoder = self.redis.connection_pool.get_encoder() + try: + encoder = self.redis.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = self.redis.get_encoder() stored_token = encoder.encode(stored_token) return self.local.token is not None and stored_token == self.local.token diff --git a/redis/commands/core.py b/redis/commands/core.py index c245a7a..3be2823 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -4930,7 +4930,11 @@ class Script: if isinstance(script, str): # We need the encoding from the client in order to generate an # accurate byte representation of the script - encoder = registered_client.connection_pool.get_encoder() + try: + encoder = registered_client.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = registered_client.get_encoder() script = encoder.encode(script) self.sha = hashlib.sha1(script).hexdigest() @@ -4975,7 +4979,11 @@ class AsyncScript: if isinstance(script, str): # We need the encoding from the client in order to generate an # accurate byte representation of the script - encoder = registered_client.connection_pool.get_encoder() + try: + encoder = registered_client.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = registered_client.get_encoder() script = encoder.encode(script) self.sha = hashlib.sha1(script).hexdigest() diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 09aec75..56387fa 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -7,7 +7,6 @@ from redis.asyncio.lock import Lock from redis.exceptions import LockError, LockNotOwnedError -@pytest.mark.onlynoncluster class TestLock: @pytest_asyncio.fixture() async def r_decoded(self, create_redis): -- cgit v1.2.1