summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorUtkarsh Gupta <utkarshgupta137@gmail.com>2022-05-30 19:04:06 +0530
committerGitHub <noreply@github.com>2022-05-30 16:34:06 +0300
commitf704281cf4c1f735c06a13946fcea42fa939e3a5 (patch)
tree85a7affc680058b54d1df30d65e1f97c44c08847
parent48079083a7f6ac1bdd948c03175f9ffd42aa1f6b (diff)
downloadredis-py-f704281cf4c1f735c06a13946fcea42fa939e3a5.tar.gz
async_cluster: add/update typing (#2195)
* async_cluster: add/update typing * async_cluster: update cleanup_kwargs with kwargs from async Connection * async_cluster: properly remove old nodes
-rw-r--r--redis/asyncio/cluster.py162
-rw-r--r--redis/asyncio/connection.py8
-rw-r--r--redis/asyncio/parser.py14
-rw-r--r--redis/cluster.py15
-rw-r--r--redis/utils.py5
-rw-r--r--tests/test_asyncio/test_cluster.py35
6 files changed, 123 insertions, 116 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index fa91872..c4e01d6 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -3,11 +3,12 @@ import collections
import random
import socket
import warnings
-from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
+from typing import Any, Deque, Dict, Generator, List, Optional, Type, TypeVar, Union
-from redis.asyncio.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
+from redis.asyncio.client import ResponseCallbackT
from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url
from redis.asyncio.parser import CommandsParser
+from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
from redis.cluster import (
PRIMARY,
READ_COMMANDS,
@@ -15,7 +16,6 @@ from redis.cluster import (
SLOT_ID,
AbstractRedisCluster,
LoadBalancer,
- cleanup_kwargs,
get_node_name,
parse_cluster_slots,
)
@@ -41,9 +41,36 @@ from redis.typing import EncodableT, KeyT
from redis.utils import dict_merge, str_if_bytes
TargetNodesT = TypeVar(
- "TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
+ "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)
+CONNECTION_ALLOWED_KEYS = (
+ "client_name",
+ "db",
+ "decode_responses",
+ "encoder_class",
+ "encoding",
+ "encoding_errors",
+ "health_check_interval",
+ "parser_class",
+ "password",
+ "redis_connect_func",
+ "retry",
+ "retry_on_timeout",
+ "socket_connect_timeout",
+ "socket_keepalive",
+ "socket_keepalive_options",
+ "socket_read_size",
+ "socket_timeout",
+ "socket_type",
+ "username",
+)
+
+
+def cleanup_kwargs(**kwargs: Any) -> Dict[str, Any]:
+ """Remove unsupported or disabled keys from kwargs."""
+ return {k: v for k, v in kwargs.items() if k in CONNECTION_ALLOWED_KEYS}
+
class ClusterParser(DefaultParser):
EXCEPTION_CLASSES = dict_merge(
@@ -131,7 +158,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
"""
@classmethod
- def from_url(cls, url: str, **kwargs) -> "RedisCluster":
+ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
"""
Return a Redis client object configured from the given URL.
@@ -201,7 +228,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
cluster_error_retry_attempts: int = 3,
reinitialize_steps: int = 10,
url: Optional[str] = None,
- **kwargs,
+ **kwargs: Any,
) -> None:
if not startup_nodes:
startup_nodes = []
@@ -212,7 +239,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
"Argument 'db' is not possible to use in cluster mode"
)
- # Get the startup node/s
+ # Get the startup node(s)
if url:
url_options = parse_url(url)
if "path" in url_options:
@@ -247,34 +274,34 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs)
self.response_callbacks = kwargs[
"response_callbacks"
- ] = self.__class__.RESPONSE_CALLBACKS
+ ] = self.__class__.RESPONSE_CALLBACKS.copy()
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
+ self.nodes_manager = NodesManager(
+ startup_nodes=startup_nodes,
+ require_full_coverage=require_full_coverage,
+ **self.connection_kwargs,
+ )
self.encoder = Encoder(
kwargs.get("encoding", "utf-8"),
kwargs.get("encoding_errors", "strict"),
kwargs.get("decode_responses", False),
)
self.cluster_error_retry_attempts = cluster_error_retry_attempts
- self.command_flags = self.__class__.COMMAND_FLAGS.copy()
- self.node_flags = self.__class__.NODE_FLAGS.copy()
self.read_from_replicas = read_from_replicas
- self.reinitialize_counter = 0
self.reinitialize_steps = reinitialize_steps
- self.nodes_manager = NodesManager(
- startup_nodes=startup_nodes,
- require_full_coverage=require_full_coverage,
- **self.connection_kwargs,
- )
- self.result_callbacks = self.__class__.RESULT_CALLBACKS
+ self.reinitialize_counter = 0
+ self.commands_parser = CommandsParser()
+ self.node_flags = self.__class__.NODE_FLAGS.copy()
+ self.command_flags = self.__class__.COMMAND_FLAGS.copy()
+ self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
self.result_callbacks[
"CLUSTER SLOTS"
] = lambda cmd, res, **kwargs: parse_cluster_slots(
list(res.values())[0], **kwargs
)
- self.commands_parser = CommandsParser()
self._initialize = True
self._lock = asyncio.Lock()
@@ -310,16 +337,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
await self.close()
- def __await__(self):
+ def __await__(self) -> Generator[Any, None, "RedisCluster"]:
return self.initialize().__await__()
_DEL_MESSAGE = "Unclosed RedisCluster client"
- def __del__(self, _warnings=warnings):
+ def __del__(self) -> None:
if hasattr(self, "_initialize") and not self._initialize:
- _warnings.warn(
- f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self
- )
+ warnings.warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
try:
context = {"client": self, "message": self._DEL_MESSAGE}
# TODO: Change to get_running_loop() when dropping support for py3.6
@@ -408,7 +433,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
self.nodes_manager.default_node = node
- def set_response_callback(self, command: KeyT, callback: Callable) -> None:
+ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
"""Set a custom response callback."""
self.response_callbacks[command] = callback
@@ -430,7 +455,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
return key_slot(k)
async def _determine_nodes(
- self, *args, node_flag: Optional[str] = None
+ self, *args: Any, node_flag: Optional[str] = None
) -> List["ClusterNode"]:
command = args[0]
if not node_flag:
@@ -462,11 +487,11 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
)
]
- async def _determine_slot(self, *args) -> int:
+ async def _determine_slot(self, *args: Any) -> int:
command = args[0]
if self.command_flags.get(command) == SLOT_ID:
# The command contains the slot ID
- return args[1]
+ return int(args[1])
# Get the keys in the command
@@ -516,14 +541,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
return slots.pop()
- def _is_node_flag(
- self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str]
- ) -> bool:
+ def _is_node_flag(self, target_nodes: Any) -> bool:
return isinstance(target_nodes, str) and target_nodes in self.node_flags
- def _parse_target_nodes(
- self, target_nodes: Union[List["ClusterNode"], "ClusterNode"]
- ) -> List["ClusterNode"]:
+ def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
if isinstance(target_nodes, list):
nodes = target_nodes
elif isinstance(target_nodes, ClusterNode):
@@ -533,7 +554,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
# Supports dictionaries of the format {node_name: node}.
# It enables to execute commands with multi nodes as follows:
# rc.cluster_save_config(rc.get_primaries())
- nodes = target_nodes.values()
+ nodes = list(target_nodes.values())
else:
raise TypeError(
"target_nodes type can be one of the following: "
@@ -543,7 +564,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
)
return nodes
- async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any:
+ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
"""
Execute a raw command on the appropriate cluster node or target_nodes.
@@ -562,7 +583,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
can't be mapped to a slot
"""
command = args[0]
- target_nodes_specified = target_nodes = exception = None
+ target_nodes = []
+ target_nodes_specified = False
retry_attempts = self.cluster_error_retry_attempts
passed_targets = kwargs.pop("target_nodes", None)
@@ -571,7 +593,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
target_nodes_specified = True
retry_attempts = 1
- for _ in range(0, retry_attempts):
+ for _ in range(retry_attempts):
if self._initialize:
await self.initialize()
try:
@@ -622,9 +644,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
raise exception
async def _execute_command(
- self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs
+ self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
) -> Any:
- redirect_addr = asking = moved = None
+ asking = moved = False
+ redirect_addr = None
ttl = self.RedisClusterRequestTTL
connection_error_retry_counter = 0
@@ -725,8 +748,8 @@ class ClusterNode:
server_type: Optional[str] = None,
max_connections: int = 2 ** 31,
connection_class: Type[Connection] = Connection,
- response_callbacks: Dict = RedisCluster.RESPONSE_CALLBACKS,
- **connection_kwargs,
+ response_callbacks: Dict[str, Any] = RedisCluster.RESPONSE_CALLBACKS,
+ **connection_kwargs: Any,
) -> None:
if host == "localhost":
host = socket.gethostbyname(host)
@@ -743,8 +766,8 @@ class ClusterNode:
self.connection_kwargs = connection_kwargs
self.response_callbacks = response_callbacks
- self._connections = []
- self._free = collections.deque(maxlen=self.max_connections)
+ self._connections: List[Connection] = []
+ self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
def __repr__(self) -> str:
return (
@@ -752,15 +775,15 @@ class ClusterNode:
f"name={self.name}, server_type={self.server_type}]"
)
- def __eq__(self, obj: "ClusterNode") -> bool:
+ def __eq__(self, obj: Any) -> bool:
return isinstance(obj, ClusterNode) and obj.name == self.name
_DEL_MESSAGE = "Unclosed ClusterNode object"
- def __del__(self, _warnings=warnings):
+ def __del__(self) -> None:
for connection in self._connections:
if connection.is_connected:
- _warnings.warn(
+ warnings.warn(
f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self
)
try:
@@ -783,7 +806,7 @@ class ClusterNode:
if exc:
raise exc
- async def execute_command(self, *args, **kwargs) -> Any:
+ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
# Acquire connection
connection = None
if self._free:
@@ -829,11 +852,11 @@ class ClusterNode:
class NodesManager:
__slots__ = (
"_moved_exception",
- "_require_full_coverage",
"connection_kwargs",
"default_node",
"nodes_cache",
"read_load_balancer",
+ "require_full_coverage",
"slots_cache",
"startup_nodes",
)
@@ -842,23 +865,24 @@ class NodesManager:
self,
startup_nodes: List["ClusterNode"],
require_full_coverage: bool = False,
- **kwargs,
+ **kwargs: Any,
) -> None:
- self.nodes_cache = {}
- self.slots_cache = {}
self.startup_nodes = {node.name: node for node in startup_nodes}
- self.default_node = None
- self._require_full_coverage = require_full_coverage
- self._moved_exception = None
+ self.require_full_coverage = require_full_coverage
self.connection_kwargs = kwargs
+
+ self.default_node: "ClusterNode" = None
+ self.nodes_cache: Dict[str, "ClusterNode"] = {}
+ self.slots_cache: Dict[int, List["ClusterNode"]] = {}
self.read_load_balancer = LoadBalancer()
+ self._moved_exception: MovedError = None
def get_node(
self,
host: Optional[str] = None,
port: Optional[int] = None,
node_name: Optional[str] = None,
- ) -> "ClusterNode":
+ ) -> Optional["ClusterNode"]:
if host and port:
# the user passed host and port
if host == "localhost":
@@ -877,20 +901,18 @@ class NodesManager:
self,
old: Dict[str, "ClusterNode"],
new: Dict[str, "ClusterNode"],
- remove_old=False,
+ remove_old: bool = False,
) -> None:
- tasks = []
if remove_old:
- tasks = [
- asyncio.ensure_future(node.disconnect())
- for name, node in old.items()
- if name not in new
- ]
+ for name in list(old.keys()):
+ if name not in new:
+ asyncio.ensure_future(old.pop(name).disconnect())
+
for name, node in new.items():
if name in old:
if old[name] is node:
continue
- tasks.append(asyncio.ensure_future(old[name].disconnect()))
+ asyncio.ensure_future(old[name].disconnect())
old[name] = node
def _update_moved_slots(self) -> None:
@@ -949,7 +971,7 @@ class NodesManager:
except (IndexError, TypeError):
raise SlotNotCoveredError(
f'Slot "{slot}" not covered by the cluster. '
- f'"require_full_coverage={self._require_full_coverage}"'
+ f'"require_full_coverage={self.require_full_coverage}"'
)
def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
@@ -961,8 +983,8 @@ class NodesManager:
async def initialize(self) -> None:
self.read_load_balancer.reset()
- tmp_nodes_cache = {}
- tmp_slots = {}
+ tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
+ tmp_slots: Dict[int, List["ClusterNode"]] = {}
disagreements = []
startup_nodes_reachable = False
fully_covered = False
@@ -975,9 +997,7 @@ class NodesManager:
raise RedisClusterException(
"Cluster mode is not enabled on this node"
)
- cluster_slots = str_if_bytes(
- await startup_node.execute_command("CLUSTER SLOTS")
- )
+ cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
startup_nodes_reachable = True
except (ConnectionError, TimeoutError):
continue
@@ -1069,7 +1089,7 @@ class NodesManager:
# Validate if all slots are covered or if we should try next startup node
fully_covered = True
- for i in range(0, REDIS_CLUSTER_HASH_SLOTS):
+ for i in range(REDIS_CLUSTER_HASH_SLOTS):
if i not in tmp_slots:
fully_covered = False
break
@@ -1083,7 +1103,7 @@ class NodesManager:
)
# Check if the slots are not fully covered
- if not fully_covered and self._require_full_coverage:
+ if not fully_covered and self.require_full_coverage:
# Despite the requirement that the slots be covered, there
# isn't a full coverage
raise RedisClusterException(
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 9de2d46..d9b0974 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -684,7 +684,7 @@ class Connection:
def clear_connect_callbacks(self):
self._connect_callbacks = []
- def set_parser(self, parser_class):
+ def set_parser(self, parser_class: Type[BaseParser]) -> None:
"""
Creates a new instance of parser_class with socket size:
_socket_read_size and assigns it to the parser for the connection
@@ -766,7 +766,7 @@ class Connection:
f"{exception.args[0]}."
)
- async def on_connect(self):
+ async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
self._parser.on_connect(self)
@@ -807,7 +807,7 @@ class Connection:
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Invalid Database")
- async def disconnect(self):
+ async def disconnect(self) -> None:
"""Disconnects from the Redis server"""
try:
async with async_timeout.timeout(self.socket_connect_timeout):
@@ -891,7 +891,7 @@ class Connection:
await self.disconnect()
raise
- async def send_command(self, *args, **kwargs):
+ async def send_command(self, *args: Any, **kwargs: Any) -> None:
"""Pack and send a command to the Redis server"""
await self.send_packed_command(
self.pack_command(*args), check_health=kwargs.get("check_health", True)
diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py
index 273fe03..6286351 100644
--- a/redis/asyncio/parser.py
+++ b/redis/asyncio/parser.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from redis.exceptions import RedisError, ResponseError
@@ -25,7 +25,7 @@ class CommandsParser:
__slots__ = ("commands",)
def __init__(self) -> None:
- self.commands = {}
+ self.commands: Dict[str, Union[int, Dict[str, Any]]] = {}
async def initialize(self, r: "ClusterNode") -> None:
commands = await r.execute_command("COMMAND")
@@ -42,8 +42,8 @@ class CommandsParser:
# our logic to use COMMAND INFO changes to determine the key positions
# https://github.com/redis/redis/pull/8324
async def get_keys(
- self, redis_conn: "ClusterNode", *args
- ) -> Optional[Union[List[str], List[bytes]]]:
+ self, redis_conn: "ClusterNode", *args: Any
+ ) -> Optional[Tuple[str, ...]]:
if len(args) < 2:
# The command has no keys in it
return None
@@ -67,7 +67,7 @@ class CommandsParser:
command = self.commands[cmd_name]
if command == 1:
- return [args[1]]
+ return (args[1],)
if command == 0:
return None
if command == -1:
@@ -79,8 +79,8 @@ class CommandsParser:
return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]]
async def _get_moveable_keys(
- self, redis_conn: "ClusterNode", *args
- ) -> Optional[List[str]]:
+ self, redis_conn: "ClusterNode", *args: Any
+ ) -> Optional[Tuple[str, ...]]:
try:
keys = await redis_conn.execute_command("COMMAND GETKEYS", *args)
except ResponseError as e:
diff --git a/redis/cluster.py b/redis/cluster.py
index 0b9c543..46a96a6 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -6,6 +6,7 @@ import sys
import threading
import time
from collections import OrderedDict
+from typing import Any, Dict, Tuple
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
from redis.commands import CommandsParser, RedisClusterCommands
@@ -40,7 +41,7 @@ from redis.utils import (
log = logging.getLogger(__name__)
-def get_node_name(host, port):
+def get_node_name(host: str, port: int) -> str:
return f"{host}:{port}"
@@ -74,10 +75,12 @@ def parse_pubsub_numsub(command, res, **options):
return ret_numsub
-def parse_cluster_slots(resp, **options):
+def parse_cluster_slots(
+ resp: Any, **options: Any
+) -> Dict[Tuple[int, int], Dict[str, Any]]:
current_host = options.get("current_host", "")
- def fix_server(*args):
+ def fix_server(*args: Any) -> Tuple[str, Any]:
return str_if_bytes(args[0]) or current_host, args[1]
slots = {}
@@ -1248,17 +1251,17 @@ class LoadBalancer:
Round-Robin Load Balancing
"""
- def __init__(self, start_index=0):
+ def __init__(self, start_index: int = 0) -> None:
self.primary_to_idx = {}
self.start_index = start_index
- def get_server_index(self, primary, list_size):
+ def get_server_index(self, primary: str, list_size: int) -> int:
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
# Update the index
self.primary_to_idx[primary] = (server_index + 1) % list_size
return server_index
- def reset(self):
+ def reset(self) -> None:
self.primary_to_idx.clear()
diff --git a/redis/utils.py b/redis/utils.py
index 9ab75f2..0c34e1e 100644
--- a/redis/utils.py
+++ b/redis/utils.py
@@ -1,4 +1,5 @@
from contextlib import contextmanager
+from typing import Any, Dict, Mapping, Union
try:
import hiredis # noqa
@@ -34,7 +35,7 @@ def pipeline(redis_obj):
p.execute()
-def str_if_bytes(value):
+def str_if_bytes(value: Union[str, bytes]) -> str:
return (
value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value
)
@@ -44,7 +45,7 @@ def safe_str(value):
return str(str_if_bytes(value))
-def dict_merge(*dicts):
+def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
"""
Merge all provided dicts into 1 dict.
*dicts : `dict`
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 6c28ce3..123adc8 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -3,6 +3,7 @@ import binascii
import datetime
import sys
import warnings
+from typing import Any, Callable, Dict, List, Optional, Type, Union
import pytest
@@ -13,21 +14,13 @@ if sys.version_info[0:2] == (3, 6):
else:
import pytest_asyncio
-from typing import Callable, Dict, List, Optional, Type, Union
-
from _pytest.fixtures import FixtureRequest, SubRequest
from redis.asyncio import Connection, RedisCluster
-from redis.asyncio.cluster import (
- PRIMARY,
- REDIS_CLUSTER_HASH_SLOTS,
- REPLICA,
- ClusterNode,
- NodesManager,
- get_node_name,
-)
+from redis.asyncio.cluster import ClusterNode, NodesManager
from redis.asyncio.parser import CommandsParser
-from redis.crc import key_slot
+from redis.cluster import PRIMARY, REPLICA, get_node_name
+from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
from redis.exceptions import (
AskError,
ClusterDownError,
@@ -109,7 +102,7 @@ async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster:
CommandsParser, "initialize", autospec=True
) as cmd_parser_initialize:
- def cmd_init_mock(self, r):
+ def cmd_init_mock(self, r: ClusterNode) -> None:
self.commands = {
"GET": {
"name": "get",
@@ -126,12 +119,7 @@ async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster:
return await RedisCluster(*args, **kwargs)
-def mock_node_resp(
- node: ClusterNode,
- response: Union[
- List[List[Union[int, List[Union[str, int]]]]], List[bytes], str, int
- ],
-) -> ClusterNode:
+def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode:
connection = mock.AsyncMock()
connection.is_connected = True
connection.read_response_without_lock.return_value = response
@@ -141,12 +129,7 @@ def mock_node_resp(
return node
-def mock_all_nodes_resp(
- rc: RedisCluster,
- response: Union[
- List[List[Union[int, List[Union[str, int]]]]], List[bytes], int, str
- ],
-) -> RedisCluster:
+def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster:
for node in rc.get_nodes():
mock_node_resp(node, response)
return rc
@@ -461,7 +444,7 @@ class TestRedisClusterObj:
CommandsParser, "initialize", autospec=True
) as cmd_parser_initialize:
- def cmd_init_mock(self, r):
+ def cmd_init_mock(self, r: ClusterNode) -> None:
self.commands = {
"GET": {
"name": "get",
@@ -2217,7 +2200,7 @@ class TestNodesManager:
CommandsParser, "initialize", autospec=True
) as cmd_parser_initialize:
- def cmd_init_mock(self, r):
+ def cmd_init_mock(self, r: ClusterNode) -> None:
self.commands = {
"GET": {
"name": "get",