summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorUtkarsh Gupta <utkarshgupta137@gmail.com>2022-05-30 21:45:45 +0530
committerGitHub <noreply@github.com>2022-05-30 19:15:45 +0300
commitbac33d4a92892ca7982b461347151bff5a661f0d (patch)
tree976d5dafcc2b3a1c4e129e1da439f1b7bdacacbd
parentc54dfa49dda6a7b3389dc230726293af3ffc68a3 (diff)
downloadredis-py-bac33d4a92892ca7982b461347151bff5a661f0d.tar.gz
async_cluster: add pipeline support (#2199)
Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
-rw-r--r--docs/connections.rst5
-rw-r--r--redis/asyncio/cluster.py328
-rw-r--r--redis/cluster.py114
-rw-r--r--tests/test_asyncio/test_cluster.py254
-rw-r--r--tests/test_asyncio/test_pipeline.py7
5 files changed, 652 insertions, 56 deletions
diff --git a/docs/connections.rst b/docs/connections.rst
index e4b82cd..b481689 100644
--- a/docs/connections.rst
+++ b/docs/connections.rst
@@ -76,6 +76,11 @@ ClusterNode (Async)
.. autoclass:: redis.asyncio.cluster.ClusterNode
:members:
+ClusterPipeline (Async)
+===================
+.. autoclass:: redis.asyncio.cluster.ClusterPipeline
+ :members:
+
Connection
**********
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 39aa536..3405a49 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -3,19 +3,32 @@ import collections
import random
import socket
import warnings
-from typing import Any, Deque, Dict, Generator, List, Optional, Type, TypeVar, Union
+from typing import (
+ Any,
+ Deque,
+ Dict,
+ Generator,
+ List,
+ Mapping,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+)
from redis.asyncio.client import ResponseCallbackT
from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url
from redis.asyncio.parser import CommandsParser
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
from redis.cluster import (
+ PIPELINE_BLOCKED_COMMANDS,
PRIMARY,
READ_COMMANDS,
REPLICA,
SLOT_ID,
AbstractRedisCluster,
LoadBalancer,
+ block_pipeline_command,
get_node_name,
parse_cluster_slots,
)
@@ -37,8 +50,8 @@ from redis.exceptions import (
TimeoutError,
TryAgainError,
)
-from redis.typing import EncodableT, KeyT
-from redis.utils import dict_merge, str_if_bytes
+from redis.typing import AnyKeyT, EncodableT, KeyT
+from redis.utils import dict_merge, safe_str, str_if_bytes
TargetNodesT = TypeVar(
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
@@ -719,6 +732,24 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
raise ClusterError("TTL exhausted.")
+ def pipeline(
+ self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
+ ) -> "ClusterPipeline":
+ """
+ Create & return a new :class:`~.ClusterPipeline` object.
+
+ Cluster implementation of pipeline does not support transaction or shard_hint.
+
+ :raises RedisClusterException: if transaction or shard_hint are truthy values
+ """
+ if shard_hint:
+ raise RedisClusterException("shard_hint is deprecated in cluster mode")
+
+ if transaction:
+ raise RedisClusterException("transaction is deprecated in cluster mode")
+
+ return ClusterPipeline(self)
+
class ClusterNode:
"""
@@ -729,6 +760,7 @@ class ClusterNode:
"""
__slots__ = (
+ "_command_stack",
"_connections",
"_free",
"connection_class",
@@ -768,6 +800,7 @@ class ClusterNode:
self._connections: List[Connection] = []
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
+ self._command_stack: List["PipelineCommand"] = []
def __repr__(self) -> str:
return (
@@ -806,27 +839,26 @@ class ClusterNode:
if exc:
raise exc
- async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
- # Acquire connection
- connection = None
+ def acquire_connection(self) -> Connection:
if self._free:
for _ in range(len(self._free)):
connection = self._free.popleft()
if connection.is_connected:
- break
+ return connection
self._free.append(connection)
- else:
- connection = self._free.popleft()
+
+ return self._free.popleft()
else:
if len(self._connections) < self.max_connections:
connection = self.connection_class(**self.connection_kwargs)
self._connections.append(connection)
+ return connection
else:
raise ConnectionError("Too many connections")
- # Execute command
- command = connection.pack_command(*args)
- await connection.send_packed_command(command, False)
+ async def parse_response(
+ self, connection: Connection, command: str, **kwargs: Any
+ ) -> Any:
try:
if NEVER_DECODE in kwargs:
response = await connection.read_response_without_lock(
@@ -838,16 +870,49 @@ class ClusterNode:
if EMPTY_RESPONSE in kwargs:
return kwargs[EMPTY_RESPONSE]
raise
- finally:
- # Release connection
- self._free.append(connection)
# Return response
try:
- return self.response_callbacks[args[0]](response, **kwargs)
+ return self.response_callbacks[command](response, **kwargs)
except KeyError:
return response
+ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
+ # Acquire connection
+ connection = self.acquire_connection()
+
+ # Execute command
+ await connection.send_packed_command(connection.pack_command(*args), False)
+
+ # Read response
+ try:
+ return await self.parse_response(connection, args[0], **kwargs)
+ finally:
+ # Release connection
+ self._free.append(connection)
+
+ async def execute_pipeline(self) -> None:
+ # Acquire connection
+ connection = self.acquire_connection()
+
+ # Execute command
+ await connection.send_packed_command(
+ connection.pack_commands(cmd.args for cmd in self._command_stack), False
+ )
+
+ # Read responses
+ try:
+ for cmd in self._command_stack:
+ try:
+ cmd.result = await self.parse_response(
+ connection, cmd.args[0], **cmd.kwargs
+ )
+ except Exception as e:
+ cmd.result = e
+ finally:
+ # Release connection
+ self._free.append(connection)
+
class NodesManager:
__slots__ = (
@@ -1131,3 +1196,234 @@ class NodesManager:
for node in getattr(self, attr).values()
)
)
+
+
+class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
+ """
+ Create a new ClusterPipeline object.
+
+ Usage::
+
+ result = await (
+ rc.pipeline()
+ .set("A", 1)
+ .get("A")
+ .hset("K", "F", "V")
+ .hgetall("K")
+ .mset_nonatomic({"A": 2, "B": 3})
+ .get("A")
+ .get("B")
+ .delete("A", "B", "K")
+ .execute()
+ )
+ # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]
+
+ Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
+ are split across multiple nodes, you'll get multiple results for them in the array.
+
+ Retryable errors:
+ - :class:`~.ClusterDownError`
+ - :class:`~.ConnectionError`
+ - :class:`~.TimeoutError`
+
+ Redirection errors:
+ - :class:`~.TryAgainError`
+ - :class:`~.MovedError`
+ - :class:`~.AskError`
+
+ :param client:
+ | Existing :class:`~.RedisCluster` client
+ """
+
+ __slots__ = ("_command_stack", "_client")
+
+ def __init__(self, client: RedisCluster) -> None:
+ self._client = client
+
+ self._command_stack: List["PipelineCommand"] = []
+
+ async def initialize(self) -> "ClusterPipeline":
+ if self._client._initialize:
+ await self._client.initialize()
+ self._command_stack = []
+ return self
+
+ async def __aenter__(self) -> "ClusterPipeline":
+ return await self.initialize()
+
+ async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
+ self._command_stack = []
+
+ def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
+ return self.initialize().__await__()
+
+ def __bool__(self) -> bool:
+ return bool(self._command_stack)
+
+ def __len__(self) -> int:
+ return len(self._command_stack)
+
+ def execute_command(
+ self, *args: Union[KeyT, EncodableT], **kwargs: Any
+ ) -> "ClusterPipeline":
+ """
+ Append a raw command to the pipeline.
+
+ :param args:
+ | Raw command args
+ :param kwargs:
+
+ - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
+ or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
+ - Rest of the kwargs are passed to the Redis connection
+ """
+ self._command_stack.append(
+ PipelineCommand(len(self._command_stack), *args, **kwargs)
+ )
+ return self
+
+ async def execute(
+ self, raise_on_error: bool = True, allow_redirections: bool = True
+ ) -> List[Any]:
+ """
+ Execute the pipeline.
+
+ It will retry the commands as specified by :attr:`cluster_error_retry_attempts`
+ & then raise an exception.
+
+ :param raise_on_error:
+ | Raise the first error if there are any errors
+ :param allow_redirections:
+ | Whether to retry each failed command individually in case of redirection
+ errors
+
+ :raises RedisClusterException: if target_nodes is not provided & the command
+ can't be mapped to a slot
+ """
+ if not self._command_stack:
+ return []
+
+ try:
+ for _ in range(self._client.cluster_error_retry_attempts):
+ if self._client._initialize:
+ await self._client.initialize()
+
+ try:
+ return await self._execute(
+ self._command_stack,
+ raise_on_error=raise_on_error,
+ allow_redirections=allow_redirections,
+ )
+ except BaseException as e:
+ if type(e) in self.__class__.ERRORS_ALLOW_RETRY:
+ # Try again with the new cluster setup.
+ exception = e
+ await self._client.close()
+ await asyncio.sleep(0.25)
+ else:
+ # All other errors should be raised.
+ raise e
+
+ # If it fails the configured number of times then raise an exception
+ raise exception
+ finally:
+ self._command_stack = []
+
+ async def _execute(
+ self,
+ stack: List["PipelineCommand"],
+ raise_on_error: bool = True,
+ allow_redirections: bool = True,
+ ) -> List[Any]:
+ client = self._client
+ nodes = {}
+ for cmd in stack:
+ if not cmd.result or isinstance(cmd.result, Exception):
+ target_nodes = await client._determine_nodes(*cmd.args)
+ if not target_nodes:
+ raise RedisClusterException(
+ f"No targets were found to execute {cmd.args} command on"
+ )
+ if len(target_nodes) > 1:
+ raise RedisClusterException(
+ f"Too many targets for command {cmd.args}"
+ )
+
+ node = target_nodes[0]
+ if node.name not in nodes:
+ nodes[node.name] = node
+ node._command_stack = []
+ node._command_stack.append(cmd)
+
+ await asyncio.gather(
+ *(asyncio.ensure_future(node.execute_pipeline()) for node in nodes.values())
+ )
+
+ if allow_redirections:
+ # send each errored command individually
+ for cmd in stack:
+ if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
+ try:
+ cmd.result = await client.execute_command(
+ *cmd.args, **cmd.kwargs
+ )
+ except Exception as e:
+ cmd.result = e
+
+ responses = [cmd.result for cmd in stack]
+
+ if raise_on_error:
+ for cmd in stack:
+ result = cmd.result
+ if isinstance(result, Exception):
+ command = " ".join(map(safe_str, cmd.args))
+ msg = (
+ f"Command # {cmd.position + 1} ({command}) of pipeline "
+ f"caused error: {result.args}"
+ )
+ result.args = (msg,) + result.args[1:]
+ raise result
+
+ return responses
+
+ def _split_command_across_slots(
+ self, command: str, *keys: KeyT
+ ) -> "ClusterPipeline":
+ for slot_keys in self._client._partition_keys_by_slot(keys).values():
+ self.execute_command(command, *slot_keys)
+
+ return self
+
+ def mset_nonatomic(
+ self, mapping: Mapping[AnyKeyT, EncodableT]
+ ) -> "ClusterPipeline":
+ encoder = self._client.encoder
+
+ slots_pairs = {}
+ for pair in mapping.items():
+ slot = key_slot(encoder.encode(pair[0]))
+ slots_pairs.setdefault(slot, []).extend(pair)
+
+ for pairs in slots_pairs.values():
+ self.execute_command("MSET", *pairs)
+
+ return self
+
+
+for command in PIPELINE_BLOCKED_COMMANDS:
+ command = command.replace(" ", "_").lower()
+ if command == "mset_nonatomic":
+ continue
+
+ setattr(ClusterPipeline, command, block_pipeline_command(command))
+
+
+class PipelineCommand:
+ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
+ self.args = args
+ self.kwargs = kwargs
+ self.position = position
+ self.result: Union[Any, Exception] = None
+
+ def __repr__(self) -> str:
+ return f"[{self.position}] {self.args} ({self.kwargs})"
diff --git a/redis/cluster.py b/redis/cluster.py
index 46a96a6..2e31063 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -6,7 +6,7 @@ import sys
import threading
import time
from collections import OrderedDict
-from typing import Any, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
from redis.commands import CommandsParser, RedisClusterCommands
@@ -2130,7 +2130,7 @@ class ClusterPipeline(RedisCluster):
return self.execute_command("DEL", names[0])
-def block_pipeline_command(func):
+def block_pipeline_command(name: str) -> Callable[..., Any]:
"""
Prints error because some pipelined commands should
be blocked when running in cluster-mode
@@ -2138,7 +2138,7 @@ def block_pipeline_command(func):
def inner(*args, **kwargs):
raise RedisClusterException(
- f"ERROR: Calling pipelined function {func.__name__} is blocked "
+ f"ERROR: Calling pipelined function {name} is blocked "
f"when running redis in cluster mode..."
)
@@ -2146,39 +2146,81 @@ def block_pipeline_command(func):
# Blocked pipeline commands
-ClusterPipeline.bitop = block_pipeline_command(RedisCluster.bitop)
-ClusterPipeline.brpoplpush = block_pipeline_command(RedisCluster.brpoplpush)
-ClusterPipeline.client_getname = block_pipeline_command(RedisCluster.client_getname)
-ClusterPipeline.client_list = block_pipeline_command(RedisCluster.client_list)
-ClusterPipeline.client_setname = block_pipeline_command(RedisCluster.client_setname)
-ClusterPipeline.config_set = block_pipeline_command(RedisCluster.config_set)
-ClusterPipeline.dbsize = block_pipeline_command(RedisCluster.dbsize)
-ClusterPipeline.flushall = block_pipeline_command(RedisCluster.flushall)
-ClusterPipeline.flushdb = block_pipeline_command(RedisCluster.flushdb)
-ClusterPipeline.keys = block_pipeline_command(RedisCluster.keys)
-ClusterPipeline.mget = block_pipeline_command(RedisCluster.mget)
-ClusterPipeline.move = block_pipeline_command(RedisCluster.move)
-ClusterPipeline.mset = block_pipeline_command(RedisCluster.mset)
-ClusterPipeline.msetnx = block_pipeline_command(RedisCluster.msetnx)
-ClusterPipeline.pfmerge = block_pipeline_command(RedisCluster.pfmerge)
-ClusterPipeline.pfcount = block_pipeline_command(RedisCluster.pfcount)
-ClusterPipeline.ping = block_pipeline_command(RedisCluster.ping)
-ClusterPipeline.publish = block_pipeline_command(RedisCluster.publish)
-ClusterPipeline.randomkey = block_pipeline_command(RedisCluster.randomkey)
-ClusterPipeline.rename = block_pipeline_command(RedisCluster.rename)
-ClusterPipeline.renamenx = block_pipeline_command(RedisCluster.renamenx)
-ClusterPipeline.rpoplpush = block_pipeline_command(RedisCluster.rpoplpush)
-ClusterPipeline.scan = block_pipeline_command(RedisCluster.scan)
-ClusterPipeline.sdiff = block_pipeline_command(RedisCluster.sdiff)
-ClusterPipeline.sdiffstore = block_pipeline_command(RedisCluster.sdiffstore)
-ClusterPipeline.sinter = block_pipeline_command(RedisCluster.sinter)
-ClusterPipeline.sinterstore = block_pipeline_command(RedisCluster.sinterstore)
-ClusterPipeline.smove = block_pipeline_command(RedisCluster.smove)
-ClusterPipeline.sort = block_pipeline_command(RedisCluster.sort)
-ClusterPipeline.sunion = block_pipeline_command(RedisCluster.sunion)
-ClusterPipeline.sunionstore = block_pipeline_command(RedisCluster.sunionstore)
-ClusterPipeline.readwrite = block_pipeline_command(RedisCluster.readwrite)
-ClusterPipeline.readonly = block_pipeline_command(RedisCluster.readonly)
+PIPELINE_BLOCKED_COMMANDS = (
+ "BGREWRITEAOF",
+ "BGSAVE",
+ "BITOP",
+ "BRPOPLPUSH",
+ "CLIENT GETNAME",
+ "CLIENT KILL",
+ "CLIENT LIST",
+ "CLIENT SETNAME",
+ "CLIENT",
+ "CONFIG GET",
+ "CONFIG RESETSTAT",
+ "CONFIG REWRITE",
+ "CONFIG SET",
+ "CONFIG",
+ "DBSIZE",
+ "ECHO",
+ "EVALSHA",
+ "FLUSHALL",
+ "FLUSHDB",
+ "INFO",
+ "KEYS",
+ "LASTSAVE",
+ "MGET",
+ "MGET NONATOMIC",
+ "MOVE",
+ "MSET",
+ "MSET NONATOMIC",
+ "MSETNX",
+ "PFCOUNT",
+ "PFMERGE",
+ "PING",
+ "PUBLISH",
+ "RANDOMKEY",
+ "READONLY",
+ "READWRITE",
+ "RENAME",
+ "RENAMENX",
+ "RPOPLPUSH",
+ "SAVE",
+ "SCAN",
+ "SCRIPT EXISTS",
+ "SCRIPT FLUSH",
+ "SCRIPT KILL",
+ "SCRIPT LOAD",
+ "SCRIPT",
+ "SDIFF",
+ "SDIFFSTORE",
+ "SENTINEL GET MASTER ADDR BY NAME",
+ "SENTINEL MASTER",
+ "SENTINEL MASTERS",
+ "SENTINEL MONITOR",
+ "SENTINEL REMOVE",
+ "SENTINEL SENTINELS",
+ "SENTINEL SET",
+ "SENTINEL SLAVES",
+ "SENTINEL",
+ "SHUTDOWN",
+ "SINTER",
+ "SINTERSTORE",
+ "SLAVEOF",
+ "SLOWLOG GET",
+ "SLOWLOG LEN",
+ "SLOWLOG RESET",
+ "SLOWLOG",
+ "SMOVE",
+ "SORT",
+ "SUNION",
+ "SUNIONSTORE",
+ "TIME",
+)
+for command in PIPELINE_BLOCKED_COMMANDS:
+ command = command.replace(" ", "_").lower()
+
+ setattr(ClusterPipeline, command, block_pipeline_command(command))
class PipelineCommand:
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 123adc8..0c676cb 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -19,7 +19,7 @@ from _pytest.fixtures import FixtureRequest, SubRequest
from redis.asyncio import Connection, RedisCluster
from redis.asyncio.cluster import ClusterNode, NodesManager
from redis.asyncio.parser import CommandsParser
-from redis.cluster import PRIMARY, REPLICA, get_node_name
+from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
from redis.exceptions import (
AskError,
@@ -129,6 +129,16 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode:
return node
+def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode:
+ connection = mock.AsyncMock()
+ connection.is_connected = True
+ connection.read_response_without_lock.side_effect = exc
+ while node._free:
+ node._free.pop()
+ node._free.append(connection)
+ return node
+
+
def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster:
for node in rc.get_nodes():
mock_node_resp(node, response)
@@ -2218,3 +2228,245 @@ class TestNodesManager:
async with RedisCluster(startup_nodes=[node_1, node_2]) as rc:
assert rc.get_node(host=default_host, port=7001) is not None
assert rc.get_node(host=default_host, port=7002) is not None
+
+
+@pytest.mark.onlycluster
+class TestClusterPipeline:
+ """Tests for the ClusterPipeline class."""
+
+ async def test_blocked_arguments(self, r: RedisCluster) -> None:
+ """Test handling for blocked pipeline arguments."""
+ with pytest.raises(RedisClusterException) as ex:
+ r.pipeline(transaction=True)
+
+ assert str(ex.value) == "transaction is deprecated in cluster mode"
+
+ with pytest.raises(RedisClusterException) as ex:
+ r.pipeline(shard_hint=True)
+
+ assert str(ex.value) == "shard_hint is deprecated in cluster mode"
+
+ async def test_blocked_methods(self, r: RedisCluster) -> None:
+ """Test handling for blocked pipeline commands."""
+ pipeline = r.pipeline()
+ for command in PIPELINE_BLOCKED_COMMANDS:
+ command = command.replace(" ", "_").lower()
+ if command == "mset_nonatomic":
+ continue
+
+ with pytest.raises(RedisClusterException) as exc:
+ getattr(pipeline, command)()
+
+ assert str(exc.value) == (
+ f"ERROR: Calling pipelined function {command} is blocked "
+ "when running redis in cluster mode..."
+ )
+
+ async def test_empty_stack(self, r: RedisCluster) -> None:
+ """If a pipeline is executed with no commands it should return a empty list."""
+ p = r.pipeline()
+ result = await p.execute()
+ assert result == []
+
+ async def test_redis_cluster_pipeline(self, r: RedisCluster) -> None:
+ """Test that we can use a pipeline with the RedisCluster class"""
+ result = await (
+ r.pipeline()
+ .set("A", 1)
+ .get("A")
+ .hset("K", "F", "V")
+ .hgetall("K")
+ .mset_nonatomic({"A": 2, "B": 3})
+ .get("A")
+ .get("B")
+ .delete("A", "B", "K")
+ .execute()
+ )
+ assert result == [True, b"1", 1, {b"F": b"V"}, True, True, b"2", b"3", 1, 1, 1]
+
+ async def test_multi_key_operation_with_a_single_slot(
+ self, r: RedisCluster
+ ) -> None:
+ """Test multi key operation with a single slot."""
+ pipe = r.pipeline()
+ pipe.set("a{foo}", 1)
+ pipe.set("b{foo}", 2)
+ pipe.set("c{foo}", 3)
+ pipe.get("a{foo}")
+ pipe.get("b{foo}")
+ pipe.get("c{foo}")
+
+ res = await pipe.execute()
+ assert res == [True, True, True, b"1", b"2", b"3"]
+
+ async def test_multi_key_operation_with_multi_slots(self, r: RedisCluster) -> None:
+ """Test multi key operation with more than one slot."""
+ pipe = r.pipeline()
+ pipe.set("a{foo}", 1)
+ pipe.set("b{foo}", 2)
+ pipe.set("c{foo}", 3)
+ pipe.set("bar", 4)
+ pipe.set("bazz", 5)
+ pipe.get("a{foo}")
+ pipe.get("b{foo}")
+ pipe.get("c{foo}")
+ pipe.get("bar")
+ pipe.get("bazz")
+ res = await pipe.execute()
+ assert res == [True, True, True, True, True, b"1", b"2", b"3", b"4", b"5"]
+
+ async def test_cluster_down_error(self, r: RedisCluster) -> None:
+ """
+ Test that the pipeline retries cluster_error_retry_attempts times before raising
+ an error.
+ """
+ key = "foo"
+ node = r.get_node_from_key(key, False)
+
+ parse_response_orig = node.parse_response
+ with mock.patch.object(
+ ClusterNode, "parse_response", autospec=True
+ ) as parse_response_mock:
+
+ async def parse_response(
+ self, connection: Connection, command: str, **kwargs: Any
+ ) -> Any:
+ if command == "GET":
+ raise ClusterDownError("error")
+ return await parse_response_orig(connection, command, **kwargs)
+
+ parse_response_mock.side_effect = parse_response
+
+ # For each ClusterDownError, we launch 4 commands: INFO, CLUSTER SLOTS,
+ # COMMAND, GET. Before any errors, the first 3 commands are already run
+ async with r.pipeline() as pipe:
+ with pytest.raises(ClusterDownError):
+ await pipe.get(key).execute()
+
+ assert (
+ node.parse_response.await_count
+ == 4 * r.cluster_error_retry_attempts - 3
+ )
+
+ async def test_connection_error_not_raised(self, r: RedisCluster) -> None:
+ """Test ConnectionError handling with raise_on_error=False."""
+ key = "foo"
+ node = r.get_node_from_key(key, False)
+
+ parse_response_orig = node.parse_response
+ with mock.patch.object(
+ ClusterNode, "parse_response", autospec=True
+ ) as parse_response_mock:
+
+ async def parse_response(
+ self, connection: Connection, command: str, **kwargs: Any
+ ) -> Any:
+ if command == "GET":
+ raise ConnectionError("error")
+ return await parse_response_orig(connection, command, **kwargs)
+
+ parse_response_mock.side_effect = parse_response
+
+ async with r.pipeline() as pipe:
+ res = await pipe.get(key).get(key).execute(raise_on_error=False)
+ assert node.parse_response.await_count
+ assert isinstance(res[0], ConnectionError)
+
+ async def test_connection_error_raised(self, r: RedisCluster) -> None:
+ """Test ConnectionError handling with raise_on_error=True."""
+ key = "foo"
+ node = r.get_node_from_key(key, False)
+
+ parse_response_orig = node.parse_response
+ with mock.patch.object(
+ ClusterNode, "parse_response", autospec=True
+ ) as parse_response_mock:
+
+ async def parse_response(
+ self, connection: Connection, command: str, **kwargs: Any
+ ) -> Any:
+ if command == "GET":
+ raise ConnectionError("error")
+ return await parse_response_orig(connection, command, **kwargs)
+
+ parse_response_mock.side_effect = parse_response
+
+ async with r.pipeline() as pipe:
+ with pytest.raises(ConnectionError):
+ await pipe.get(key).get(key).execute(raise_on_error=True)
+
+ async def test_asking_error(self, r: RedisCluster) -> None:
+ """Test AskError handling."""
+ key = "foo"
+ first_node = r.get_node_from_key(key, False)
+ ask_node = None
+ for node in r.get_nodes():
+ if node != first_node:
+ ask_node = node
+ break
+ ask_msg = f"{r.keyslot(key)} {ask_node.host}:{ask_node.port}"
+
+ async with r.pipeline() as pipe:
+ mock_node_resp_exc(first_node, AskError(ask_msg))
+ mock_node_resp(ask_node, "MOCK_OK")
+ res = await pipe.get(key).execute()
+ assert first_node._free.pop().read_response_without_lock.await_count
+ assert ask_node._free.pop().read_response_without_lock.await_count
+ assert res == ["MOCK_OK"]
+
+ async def test_moved_redirection_on_slave_with_default(
+ self, r: RedisCluster
+ ) -> None:
+ """Test MovedError handling."""
+ key = "foo"
+ await r.set("foo", "bar")
+ # set read_from_replicas to True
+ r.read_from_replicas = True
+ primary = r.get_node_from_key(key, False)
+ moved_error = f"{r.keyslot(key)} {primary.host}:{primary.port}"
+
+ parse_response_orig = primary.parse_response
+ with mock.patch.object(
+ ClusterNode, "parse_response", autospec=True
+ ) as parse_response_mock:
+
+ async def parse_response(
+ self, connection: Connection, command: str, **kwargs: Any
+ ) -> Any:
+ if (
+ command == "GET"
+ and self.host != primary.host
+ and self.port != primary.port
+ ):
+ raise MovedError(moved_error)
+
+ return await parse_response_orig(connection, command, **kwargs)
+
+ parse_response_mock.side_effect = parse_response
+
+ async with r.pipeline() as readwrite_pipe:
+ assert r.reinitialize_counter == 0
+ readwrite_pipe.get(key).get(key)
+ assert r.reinitialize_counter == 0
+ assert await readwrite_pipe.execute() == [b"bar", b"bar"]
+
+ async def test_readonly_pipeline_from_readonly_client(
+ self, r: RedisCluster
+ ) -> None:
+ """Test that the pipeline uses replicas for read_from_replicas clients."""
+ # Create a cluster with reading from replications
+ r.read_from_replicas = True
+ key = "bar"
+ await r.set(key, "foo")
+
+ async with r.pipeline() as pipe:
+ mock_all_nodes_resp(r, "MOCK_OK")
+ assert await pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"]
+ slot_nodes = r.nodes_manager.slots_cache[r.keyslot(key)]
+ executed_on_replica = False
+ for node in slot_nodes:
+ if node.server_type == REPLICA:
+ if node._free.pop().read_response_without_lock.await_count:
+ executed_on_replica = True
+ break
+ assert executed_on_replica
diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py
index 50a1051..dfeb664 100644
--- a/tests/test_asyncio/test_pipeline.py
+++ b/tests/test_asyncio/test_pipeline.py
@@ -8,8 +8,8 @@ from .conftest import wait_for_command
pytestmark = pytest.mark.asyncio
-@pytest.mark.onlynoncluster
class TestPipeline:
+ @pytest.mark.onlynoncluster
async def test_pipeline_is_true(self, r):
"""Ensure pipeline instances are not false-y"""
async with r.pipeline() as pipe:
@@ -52,7 +52,6 @@ class TestPipeline:
await pipe.execute()
assert len(pipe) == 0
- @pytest.mark.onlynoncluster
async def test_pipeline_no_transaction(self, r):
async with r.pipeline(transaction=False) as pipe:
pipe.set("a", "a1").set("b", "b1").set("c", "c1")
@@ -61,6 +60,7 @@ class TestPipeline:
assert await r.get("b") == b"b1"
assert await r.get("c") == b"c1"
+ @pytest.mark.onlynoncluster
async def test_pipeline_no_transaction_watch(self, r):
await r.set("a", 0)
@@ -72,6 +72,7 @@ class TestPipeline:
pipe.set("a", int(a) + 1)
assert await pipe.execute() == [True]
+ @pytest.mark.onlynoncluster
async def test_pipeline_no_transaction_watch_failure(self, r):
await r.set("a", 0)
@@ -375,7 +376,7 @@ class TestPipeline:
async def test_pipeline_get(self, r):
await r.set("a", "a1")
async with r.pipeline() as pipe:
- await pipe.get("a")
+ pipe.get("a")
assert await pipe.execute() == [b"a1"]
@pytest.mark.onlynoncluster