diff options
author | Utkarsh Gupta <utkarshgupta137@gmail.com> | 2022-05-08 17:34:20 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-08 15:04:20 +0300 |
commit | 061d97abe21d3a8ce9738330cabf771dd05c8dc1 (patch) | |
tree | e64d64c5917312a65304f4d630775d08adb38bfa | |
parent | c25be04d6468163d31908774ed358d3fd6bc0a39 (diff) | |
download | redis-py-061d97abe21d3a8ce9738330cabf771dd05c8dc1.tar.gz |
Add Async RedisCluster (#2099)
* Copy Cluster Client, Commands, Commands Parser, Tests for asyncio
* Async Cluster Tests: Async/Await
* Add Async RedisCluster
* cluster: use ERRORS_ALLOW_RETRY from self.__class__
* async_cluster: rework redis_connection, initialize, & close
- move redis_connection from NodesManager to ClusterNode & handle all related logic in ClusterNode class
- use Locks while initializing or closing
- in case of error, close connections instead of instantly reinitializing
- create ResourceWarning instead of manually deleting client object
- use asyncio.gather to run commands/initialize/close in parallel
- inline single use functions
- fix test_acl_log for py3.6
* async_cluster: add types
* async_cluster: add docs
* docs: update sphinx & add sphinx_autodoc_typehints
* async_cluster: move TargetNodesT to cluster module
* async_cluster/commands: inherit commands from sync class if possible
* async_cluster: add benchmark script with aredis & aioredis-cluster
* async_cluster: remove logging
* async_cluster: inline functions
* async_cluster: manage Connection instead of Redis Client
* async_cluster/commands: optimize parser
* async_cluster: use ensure_future & generators for gather
* async_conn: optimize
* async_cluster: optimize determine_slot
* async_cluster: optimize determine_nodes
* async_cluster/parser: optimize _get_moveable_keys
* async_cluster: inlined check_slots_coverage
* async_cluster: update docstrings
* async_cluster: add concurrent test & use read_response/_update_moved_slots without lock
Co-authored-by: Chayim <chayim@users.noreply.github.com>
-rw-r--r-- | benchmarks/cluster_async.py | 263 | ||||
-rw-r--r-- | docs/conf.py | 9 | ||||
-rw-r--r-- | docs/connections.rst | 74 | ||||
-rw-r--r-- | docs/requirements.txt | 3 | ||||
-rw-r--r-- | redis/asyncio/__init__.py | 4 | ||||
-rw-r--r-- | redis/asyncio/client.py | 12 | ||||
-rw-r--r-- | redis/asyncio/cluster.py | 1113 | ||||
-rw-r--r-- | redis/asyncio/connection.py | 140 | ||||
-rw-r--r-- | redis/asyncio/parser.py | 95 | ||||
-rwxr-xr-x | redis/client.py | 17 | ||||
-rw-r--r-- | redis/cluster.py | 200 | ||||
-rw-r--r-- | redis/commands/__init__.py | 3 | ||||
-rw-r--r-- | redis/commands/cluster.py | 567 | ||||
-rw-r--r-- | redis/commands/parser.py | 1 | ||||
-rw-r--r-- | redis/crc.py | 4 | ||||
-rw-r--r-- | redis/typing.py | 12 | ||||
-rw-r--r-- | tests/conftest.py | 11 | ||||
-rw-r--r-- | tests/test_asyncio/conftest.py | 66 | ||||
-rw-r--r-- | tests/test_asyncio/test_cluster.py | 2232 | ||||
-rw-r--r-- | tests/test_asyncio/test_commands.py | 141 | ||||
-rw-r--r-- | tests/test_asyncio/test_connection.py | 2 | ||||
-rw-r--r-- | tests/test_asyncio/test_connection_pool.py | 79 | ||||
-rw-r--r-- | tests/test_asyncio/test_lock.py | 2 | ||||
-rw-r--r-- | tests/test_asyncio/test_retry.py | 2 | ||||
-rw-r--r-- | tests/test_cluster.py | 6 | ||||
-rw-r--r-- | tox.ini | 5 | ||||
-rw-r--r-- | whitelist.py | 2 |
27 files changed, 4541 insertions, 524 deletions
diff --git a/benchmarks/cluster_async.py b/benchmarks/cluster_async.py new file mode 100644 index 0000000..aec3f1c --- /dev/null +++ b/benchmarks/cluster_async.py @@ -0,0 +1,263 @@ +import asyncio +import functools +import time + +import aioredis_cluster +import aredis +import uvloop + +import redis.asyncio as redispy + + +def timer(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + tic = time.perf_counter() + await func(*args, **kwargs) + toc = time.perf_counter() + return f"{toc - tic:.4f}" + + return wrapper + + +@timer +async def set_str(client, gather, data): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *( + asyncio.create_task(client.set(f"bench:str_{i}", data)) + for i in range(100) + ) + ) + else: + for i in range(count): + await client.set(f"bench:str_{i}", data) + + +@timer +async def set_int(client, gather, data): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *( + asyncio.create_task(client.set(f"bench:int_{i}", data)) + for i in range(100) + ) + ) + else: + for i in range(count): + await client.set(f"bench:int_{i}", data) + + +@timer +async def get_str(client, gather): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *(asyncio.create_task(client.get(f"bench:str_{i}")) for i in range(100)) + ) + else: + for i in range(count): + await client.get(f"bench:str_{i}") + + +@timer +async def get_int(client, gather): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *(asyncio.create_task(client.get(f"bench:int_{i}")) for i in range(100)) + ) + else: + for i in range(count): + await client.get(f"bench:int_{i}") + + +@timer +async def hset(client, gather, data): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *( + asyncio.create_task(client.hset("bench:hset", str(i), data)) + for i in range(100) + ) + ) + else: + for i in range(count): + await client.hset("bench:hset", str(i), data) + + +@timer +async def hget(client, gather): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *( + asyncio.create_task(client.hget("bench:hset", str(i))) + for i in range(100) + ) + ) + else: + for i in range(count): + await client.hget("bench:hset", str(i)) + + +@timer +async def incr(client, gather): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *(asyncio.create_task(client.incr("bench:incr")) for i in range(100)) + ) + else: + for i in range(count): + await client.incr("bench:incr") + + +@timer +async def lpush(client, gather, data): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *( + asyncio.create_task(client.lpush("bench:lpush", data)) + for i in range(100) + ) + ) + else: + for i in range(count): + await client.lpush("bench:lpush", data) + + +@timer +async def lrange_300(client, gather): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *( + asyncio.create_task(client.lrange("bench:lpush", i, i + 300)) + for i in range(100) + ) + ) + else: + for i in range(count): + await client.lrange("bench:lpush", i, i + 300) + + +@timer +async def lpop(client, gather): + if gather: + for _ in range(count // 100): + await asyncio.gather( + *(asyncio.create_task(client.lpop("bench:lpush")) for i in range(100)) + ) + else: + for i in range(count): + await client.lpop("bench:lpush") + + +@timer +async def warmup(client): + await asyncio.gather( + *(asyncio.create_task(client.exists(f"bench:warmup_{i}")) for i in range(100)) + ) + + +@timer +async def run(client, gather): + data_str = "a" * size + data_int = int("1" * size) + + if gather is False: + for ret in await asyncio.gather( + asyncio.create_task(set_str(client, gather, data_str)), + asyncio.create_task(set_int(client, gather, data_int)), + asyncio.create_task(hset(client, gather, data_str)), + asyncio.create_task(incr(client, gather)), + asyncio.create_task(lpush(client, gather, data_int)), + ): + print(ret) + for ret in await asyncio.gather( + asyncio.create_task(get_str(client, gather)), + asyncio.create_task(get_int(client, gather)), + asyncio.create_task(hget(client, gather)), + asyncio.create_task(lrange_300(client, gather)), + asyncio.create_task(lpop(client, gather)), + ): + print(ret) + else: + print(await set_str(client, gather, data_str)) + print(await set_int(client, gather, data_int)) + print(await hset(client, gather, data_str)) + print(await incr(client, gather)) + print(await lpush(client, gather, data_int)) + + print(await get_str(client, gather)) + print(await get_int(client, gather)) + print(await hget(client, gather)) + print(await lrange_300(client, gather)) + print(await lpop(client, gather)) + + +async def main(loop, gather=None): + arc = aredis.StrictRedisCluster( + host=host, + port=port, + password=password, + max_connections=2 ** 31, + max_connections_per_node=2 ** 31, + readonly=False, + reinitialize_steps=count, + skip_full_coverage_check=True, + decode_responses=False, + max_idle_time=count, + idle_check_interval=count, + ) + print(f"{loop} {gather} {await warmup(arc)} aredis") + print(await run(arc, gather=gather)) + arc.connection_pool.disconnect() + + aiorc = await aioredis_cluster.create_redis_cluster( + [(host, port)], + password=password, + state_reload_interval=count, + idle_connection_timeout=count, + pool_maxsize=2 ** 31, + ) + print(f"{loop} {gather} {await warmup(aiorc)} aioredis-cluster") + print(await run(aiorc, gather=gather)) + aiorc.close() + await aiorc.wait_closed() + + async with redispy.RedisCluster( + host=host, + port=port, + password=password, + reinitialize_steps=count, + read_from_replicas=False, + decode_responses=False, + max_connections=2 ** 31, + ) as rca: + print(f"{loop} {gather} {await warmup(rca)} redispy") + print(await run(rca, gather=gather)) + + +if __name__ == "__main__": + host = "localhost" + port = 16379 + password = None + + count = 1000 + size = 16 + + asyncio.run(main("asyncio")) + asyncio.run(main("asyncio", gather=False)) + asyncio.run(main("asyncio", gather=True)) + + uvloop.install() + + asyncio.run(main("uvloop")) + asyncio.run(main("uvloop", gather=False)) + asyncio.run(main("uvloop", gather=True)) diff --git a/docs/conf.py b/docs/conf.py index b99e46c..618d95a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,6 +30,7 @@ extensions = [ "nbsphinx", "sphinx_gallery.load_style", "sphinx.ext.autodoc", + "sphinx_autodoc_typehints", "sphinx.ext.doctest", "sphinx.ext.viewcode", "sphinx.ext.autosectionlabel", @@ -41,6 +42,10 @@ extensions = [ autosectionlabel_prefix_document = True autosectionlabel_maxdepth = 2 +# AutodocTypehints settings. +always_document_param_types = True +typehints_defaults = "comma" + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -210,7 +215,7 @@ latex_elements = { # (source start file, target name, title, author, documentclass # [howto/manual]). latex_documents = [ - ("index", "redis-py.tex", "redis-py Documentation", "Redis Inc", "manual"), + ("index", "redis-py.tex", "redis-py Documentation", "Redis Inc", "manual") ] # The name of an image file (relative to this directory) to place at the top of @@ -258,7 +263,7 @@ texinfo_documents = [ "redis-py", "One line description of project.", "Miscellaneous", - ), + ) ] # Documents to append as an appendix to all manuals. diff --git a/docs/connections.rst b/docs/connections.rst index 9804a15..e4b82cd 100644 --- a/docs/connections.rst +++ b/docs/connections.rst @@ -1,20 +1,22 @@ Connecting to Redis -##################### +################### + Generic Client ************** -This is the client used to connect directly to a standard redis node. +This is the client used to connect directly to a standard Redis node. .. autoclass:: redis.Redis :members: + Sentinel Client *************** -Redis `Sentinel <https://redis.io/topics/sentinel>`_ provides high availability for Redis. There are commands that can only be executed against a redis node running in sentinel mode. Connecting to those nodes, and executing commands against them requires a Sentinel connection. +Redis `Sentinel <https://redis.io/topics/sentinel>`_ provides high availability for Redis. There are commands that can only be executed against a Redis node running in sentinel mode. Connecting to those nodes, and executing commands against them requires a Sentinel connection. -Connection example (assumes redis redis on the ports listed below): +Connection example (assumes Redis exists on the ports listed below): >>> from redis import Sentinel >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) @@ -23,33 +25,85 @@ Connection example (assumes redis redis on the ports listed below): >>> sentinel.discover_slaves('mymaster') [('127.0.0.1', 6380)] +Sentinel +======== .. autoclass:: redis.sentinel.Sentinel :members: +SentinelConnectionPool +====================== .. autoclass:: redis.sentinel.SentinelConnectionPool :members: + Cluster Client ************** -This client is used for connecting to a redis cluser. +This client is used for connecting to a Redis Cluster. +RedisCluster +============ .. autoclass:: redis.cluster.RedisCluster :members: -Connection Pools -***************** -.. autoclass:: redis.connection.ConnectionPool +ClusterNode +=========== +.. autoclass:: redis.cluster.ClusterNode :members: -More connection examples can be found `here <examples/connection_examples.html>`_. Async Client ************ +See complete example: `here <examples/asyncio_examples.html>`_ + This client is used for communicating with Redis, asynchronously. +.. autoclass:: redis.asyncio.client.Redis + :members: + + +Async Cluster Client +******************** + +RedisCluster (Async) +==================== +.. autoclass:: redis.asyncio.cluster.RedisCluster + :members: + +ClusterNode (Async) +=================== +.. autoclass:: redis.asyncio.cluster.ClusterNode + :members: + + +Connection +********** + +See complete example: `here <examples/connection_examples.html>`_ + +Connection +========== +.. autoclass:: redis.connection.Connection + :members: + +Connection (Async) +================== .. autoclass:: redis.asyncio.connection.Connection :members: -More connection examples can be found `here <examples/asyncio_examples.html>`_
\ No newline at end of file + +Connection Pools +**************** + +See complete example: `here <examples/connection_examples.html>`_ + +ConnectionPool +============== +.. autoclass:: redis.connection.ConnectionPool + :members: + +ConnectionPool (Async) +====================== +.. autoclass:: redis.asyncio.connection.ConnectionPool + :members: diff --git a/docs/requirements.txt b/docs/requirements.txt index bbb7dc6..23ddc94 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,7 @@ -sphinx<2 +sphinx<5 docutils<0.18 sphinx-rtd-theme nbsphinx sphinx_gallery ipython +sphinx-autodoc-typehints diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index c655c7d..598791a 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -1,4 +1,5 @@ from redis.asyncio.client import Redis, StrictRedis +from redis.asyncio.cluster import RedisCluster from redis.asyncio.connection import ( BlockingConnectionPool, Connection, @@ -6,6 +7,7 @@ from redis.asyncio.connection import ( SSLConnection, UnixDomainSocketConnection, ) +from redis.asyncio.parser import CommandsParser from redis.asyncio.sentinel import ( Sentinel, SentinelConnectionPool, @@ -35,6 +37,7 @@ __all__ = [ "BlockingConnectionPool", "BusyLoadingError", "ChildDeadlockedError", + "CommandsParser", "Connection", "ConnectionError", "ConnectionPool", @@ -44,6 +47,7 @@ __all__ = [ "PubSubError", "ReadOnlyError", "Redis", + "RedisCluster", "RedisError", "ResponseError", "Sentinel", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 8dde96e..6db5489 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -172,6 +172,7 @@ class Redis( username: Optional[str] = None, retry: Optional[Retry] = None, auto_close_connection_pool: bool = True, + redis_connect_func=None, ): """ Initialize a new Redis client. @@ -200,6 +201,7 @@ class Redis( "max_connections": max_connections, "health_check_interval": health_check_interval, "client_name": client_name, + "redis_connect_func": redis_connect_func, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -263,11 +265,7 @@ class Redis( """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs - def load_external_module( - self, - funcname, - func, - ): + def load_external_module(self, funcname, func): """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. @@ -426,9 +424,7 @@ class Redis( def __del__(self, _warnings: Any = warnings) -> None: if self.connection is not None: _warnings.warn( - f"Unclosed client session {self!r}", - ResourceWarning, - source=self, + f"Unclosed client session {self!r}", ResourceWarning, source=self ) context = {"client": self, "message": self._DEL_MESSAGE} asyncio.get_event_loop().call_exception_handler(context) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py new file mode 100644 index 0000000..10a5675 --- /dev/null +++ b/redis/asyncio/cluster.py @@ -0,0 +1,1113 @@ +import asyncio +import collections +import random +import socket +import warnings +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union + +from redis.asyncio.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis +from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url +from redis.asyncio.parser import CommandsParser +from redis.cluster import ( + PRIMARY, + READ_COMMANDS, + REPLICA, + SLOT_ID, + AbstractRedisCluster, + LoadBalancer, + cleanup_kwargs, + get_node_name, + parse_cluster_slots, +) +from redis.commands import AsyncRedisClusterCommands +from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot +from redis.exceptions import ( + AskError, + BusyLoadingError, + ClusterCrossSlotError, + ClusterDownError, + ClusterError, + ConnectionError, + DataError, + MasterDownError, + MovedError, + RedisClusterException, + ResponseError, + SlotNotCoveredError, + TimeoutError, + TryAgainError, +) +from redis.typing import EncodableT, KeyT +from redis.utils import dict_merge, str_if_bytes + +TargetNodesT = TypeVar( + "TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] +) + + +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, + { + "ASK": AskError, + "TRYAGAIN": TryAgainError, + "MOVED": MovedError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, + }, + ) + + +class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): + """ + Create a new RedisCluster client. + + Pass one of parameters: + + - `url` + - `host` & `port` + - `startup_nodes` + + | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. + | Use ``await`` :meth:`close` to disconnect connections & close client. + + Many commands support the target_nodes kwarg. It can be one of the + :attr:`NODE_FLAGS`: + + - :attr:`PRIMARIES` + - :attr:`REPLICAS` + - :attr:`ALL_NODES` + - :attr:`RANDOM` + - :attr:`DEFAULT_NODE` + + Note: This client is not thread/process/fork safe. + + :param host: + | Can be used to point to a startup node + :param port: + | Port used if **host** is provided + :param startup_nodes: + | :class:`~.ClusterNode` to used as a startup node + :param cluster_error_retry_attempts: + | Retry command execution attempts when encountering :class:`~.ClusterDownError` + or :class:`~.ConnectionError` + :param require_full_coverage: + | When set to ``False``: the client will not require a full coverage of the + slots. However, if not all slots are covered, and at least one node has + ``cluster-require-full-coverage`` set to ``yes``, the server will throw a + :class:`~.ClusterDownError` for some key-based commands. + | When set to ``True``: all slots must be covered to construct the cluster + client. If not all slots are covered, :class:`~.RedisClusterException` will be + thrown. + | See: + https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters + :param reinitialize_steps: + | Specifies the number of MOVED errors that need to occur before reinitializing + the whole cluster topology. If a MOVED error occurs and the cluster does not + need to be reinitialized on this current error handling, only the MOVED slot + will be patched with the redirected node. + To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. + To avoid reinitializing the cluster on moved errors, set reinitialize_steps to + 0. + :param read_from_replicas: + | Enable read from replicas in READONLY mode. You can read possibly stale data. + When set to true, read commands will be assigned between the primary and + its replications in a Round-Robin manner. + :param url: + | See :meth:`.from_url` + :param kwargs: + | Extra arguments that will be passed to the + :class:`~redis.asyncio.connection.Connection` instances when created + + :raises RedisClusterException: + if any arguments are invalid. Eg: + + - db kwarg + - db != 0 in url + - unix socket connection + - none of host & url & startup_nodes were provided + + """ + + @classmethod + def from_url(cls, url: str, **kwargs) -> "RedisCluster": + """ + Return a Redis client object configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/redis> + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/rediss> + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments and + keyword arguments are passed to :class:`~redis.asyncio.connection.Connection` + when created. In the case of conflicting arguments, querystring + arguments always win. + + """ + return cls(url=url, **kwargs) + + __slots__ = ( + "_initialize", + "_lock", + "cluster_error_retry_attempts", + "command_flags", + "commands_parser", + "connection_kwargs", + "encoder", + "node_flags", + "nodes_manager", + "read_from_replicas", + "reinitialize_counter", + "reinitialize_steps", + "response_callbacks", + "result_callbacks", + ) + + def __init__( + self, + host: Optional[str] = None, + port: int = 6379, + startup_nodes: Optional[List["ClusterNode"]] = None, + require_full_coverage: bool = False, + read_from_replicas: bool = False, + cluster_error_retry_attempts: int = 3, + reinitialize_steps: int = 10, + url: Optional[str] = None, + **kwargs, + ) -> None: + if not startup_nodes: + startup_nodes = [] + + if "db" in kwargs: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "Argument 'db' is not possible to use in cluster mode" + ) + + # Get the startup node/s + if url: + url_options = parse_url(url) + if "path" in url_options: + raise RedisClusterException( + "RedisCluster does not currently support Unix Domain " + "Socket connections" + ) + if "db" in url_options and url_options["db"] != 0: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "A ``db`` querystring option can only be 0 in cluster mode" + ) + kwargs.update(url_options) + host = kwargs.get("host") + port = kwargs.get("port", port) + elif (not host or not port) and not startup_nodes: + # No startup node was provided + raise RedisClusterException( + "RedisCluster requires at least one node to discover the " + "cluster. Please provide one of the followings:\n" + "1. host and port, for example:\n" + " RedisCluster(host='localhost', port=6379)\n" + "2. list of startup nodes, for example:\n" + " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," + " ClusterNode('localhost', 6378)])" + ) + + # Update the connection arguments + # Whenever a new connection is established, RedisCluster's on_connect + # method should be run + kwargs["redis_connect_func"] = self.on_connect + self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs) + self.response_callbacks = kwargs[ + "response_callbacks" + ] = self.__class__.RESPONSE_CALLBACKS + if host and port: + startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) + + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.node_flags = self.__class__.NODE_FLAGS.copy() + self.read_from_replicas = read_from_replicas + self.reinitialize_counter = 0 + self.reinitialize_steps = reinitialize_steps + self.nodes_manager = NodesManager( + startup_nodes=startup_nodes, + require_full_coverage=require_full_coverage, + **self.connection_kwargs, + ) + + self.result_callbacks = self.__class__.RESULT_CALLBACKS + self.result_callbacks[ + "CLUSTER SLOTS" + ] = lambda cmd, res, **kwargs: parse_cluster_slots( + list(res.values())[0], **kwargs + ) + self.commands_parser = CommandsParser() + self._initialize = True + self._lock = asyncio.Lock() + + async def initialize(self) -> "RedisCluster": + """Get all nodes from startup nodes & creates connections if not initialized.""" + if self._initialize: + async with self._lock: + if self._initialize: + self._initialize = False + try: + await self.nodes_manager.initialize() + await self.commands_parser.initialize( + self.nodes_manager.default_node + ) + except BaseException: + self._initialize = True + await self.nodes_manager.close() + await self.nodes_manager.close("startup_nodes") + raise + return self + + async def close(self) -> None: + """Close all connections & client if initialized.""" + if not self._initialize: + async with self._lock: + if not self._initialize: + self._initialize = True + await self.nodes_manager.close() + + async def __aenter__(self) -> "RedisCluster": + return await self.initialize() + + async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: + await self.close() + + def __await__(self): + return self.initialize().__await__() + + _DEL_MESSAGE = "Unclosed RedisCluster client" + + def __del__(self, _warnings=warnings): + if hasattr(self, "_initialize") and not self._initialize: + _warnings.warn( + f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self + ) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + # TODO: Change to get_running_loop() when dropping support for py3.6 + asyncio.get_event_loop().call_exception_handler(context) + except RuntimeError: + ... + + async def on_connect(self, connection: Connection) -> None: + connection.set_parser(ClusterParser) + await connection.on_connect() + + if self.read_from_replicas: + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. + await connection.send_command("READONLY") + if str_if_bytes(await connection.read_response_without_lock()) != "OK": + raise ConnectionError("READONLY command failed") + + def get_node( + self, + host: Optional[str] = None, + port: Optional[int] = None, + node_name: Optional[str] = None, + ) -> Optional["ClusterNode"]: + """Get node by (host, port) or node_name.""" + return self.nodes_manager.get_node(host, port, node_name) + + def get_primaries(self) -> List["ClusterNode"]: + """Get the primary nodes of the cluster.""" + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + + def get_replicas(self) -> List["ClusterNode"]: + """Get the replica nodes of the cluster.""" + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + + def get_random_node(self) -> "ClusterNode": + """Get a random node of the cluster.""" + return random.choice(list(self.nodes_manager.nodes_cache.values())) + + def get_nodes(self) -> List["ClusterNode"]: + """Get all nodes of the cluster.""" + return list(self.nodes_manager.nodes_cache.values()) + + def get_node_from_key( + self, key: str, replica: bool = False + ) -> Optional["ClusterNode"]: + """ + Get the cluster node corresponding to the provided key. + + :param key: + :param replica: + | Indicates if a replica should be returned + None will returned if no replica holds this key + + :raises SlotNotCoveredError: if the key is not covered by any slot. + """ + slot = self.keyslot(key) + slot_cache = self.nodes_manager.slots_cache.get(slot) + if not slot_cache: + raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') + if replica and len(self.nodes_manager.slots_cache[slot]) < 2: + return None + elif replica: + node_idx = 1 + else: + # primary + node_idx = 0 + + return slot_cache[node_idx] + + def get_default_node(self) -> "ClusterNode": + """Get the default node of the client.""" + return self.nodes_manager.default_node + + def set_default_node(self, node: "ClusterNode") -> None: + """ + Set the default node of the client. + + :raises DataError: if None is passed or node does not exist in cluster. + """ + if not node or not self.get_node(node_name=node.name): + raise DataError("The requested node does not exist in the cluster.") + + self.nodes_manager.default_node = node + + def set_response_callback(self, command: KeyT, callback: Callable) -> None: + """Set a custom response callback.""" + self.response_callbacks[command] = callback + + def get_encoder(self) -> Encoder: + """Get the encoder object of the client.""" + return self.encoder + + def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: + """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.""" + return self.connection_kwargs + + def keyslot(self, key: EncodableT) -> int: + """ + Find the keyslot for a given key. + + See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding + """ + k = self.encoder.encode(key) + return key_slot(k) + + async def _determine_nodes( + self, *args, node_flag: Optional[str] = None + ) -> List["ClusterNode"]: + command = args[0] + if not node_flag: + # get the nodes group for this command if it was predefined + node_flag = self.command_flags.get(command) + + if node_flag in self.node_flags: + if node_flag == self.__class__.DEFAULT_NODE: + # return the cluster's default node + return [self.nodes_manager.default_node] + if node_flag == self.__class__.PRIMARIES: + # return all primaries + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + if node_flag == self.__class__.REPLICAS: + # return all replicas + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + if node_flag == self.__class__.ALL_NODES: + # return all nodes + return list(self.nodes_manager.nodes_cache.values()) + if node_flag == self.__class__.RANDOM: + # return a random node + return [random.choice(list(self.nodes_manager.nodes_cache.values()))] + + # get the node that holds the key's slot + return [ + self.nodes_manager.get_node_from_slot( + await self._determine_slot(*args), + self.read_from_replicas and command in READ_COMMANDS, + ) + ] + + async def _determine_slot(self, *args) -> int: + command = args[0] + if self.command_flags.get(command) == SLOT_ID: + # The command contains the slot ID + return args[1] + + # Get the keys in the command + + # EVAL and EVALSHA are common enough that it's wasteful to go to the + # redis server to parse the keys. Besides, there is a bug in redis<7.0 + # where `self._get_command_keys()` fails anyway. So, we special case + # EVAL/EVALSHA. + # - issue: https://github.com/redis/redis/issues/9493 + # - fix: https://github.com/redis/redis/pull/9733 + if command in ("EVAL", "EVALSHA"): + # command syntax: EVAL "script body" num_keys ... + if len(args) <= 2: + raise RedisClusterException(f"Invalid args in command: {args}") + num_actual_keys = args[2] + eval_keys = args[3 : 3 + num_actual_keys] + # if there are 0 keys, that means the script can be run on any node + # so we can just return a random slot + if not eval_keys: + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + keys = eval_keys + else: + keys = await self.commands_parser.get_keys( + self.nodes_manager.default_node, *args + ) + if not keys: + # FCALL can call a function with 0 keys, that means the function + # can be run on any node so we can just return a random slot + if command in ("FCALL", "FCALL_RO"): + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + raise RedisClusterException( + "No way to dispatch this command to Redis Cluster. " + "Missing key.\nYou can execute the command by specifying " + f"target nodes.\nCommand: {args}" + ) + + # single key command + if len(keys) == 1: + return self.keyslot(keys[0]) + + # multi-key command; we need to make sure all keys are mapped to + # the same slot + slots = {self.keyslot(key) for key in keys} + if len(slots) != 1: + raise RedisClusterException( + f"{command} - all keys must map to the same key slot" + ) + + return slots.pop() + + def _is_node_flag( + self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str] + ) -> bool: + return isinstance(target_nodes, str) and target_nodes in self.node_flags + + def _parse_target_nodes( + self, target_nodes: Union[List["ClusterNode"], "ClusterNode"] + ) -> List["ClusterNode"]: + if isinstance(target_nodes, list): + nodes = target_nodes + elif isinstance(target_nodes, ClusterNode): + # Supports passing a single ClusterNode as a variable + nodes = [target_nodes] + elif isinstance(target_nodes, dict): + # Supports dictionaries of the format {node_name: node}. + # It enables to execute commands with multi nodes as follows: + # rc.cluster_save_config(rc.get_primaries()) + nodes = target_nodes.values() + else: + raise TypeError( + "target_nodes type can be one of the following: " + "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," + "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. " + f"The passed type is {type(target_nodes)}" + ) + return nodes + + async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any: + """ + Execute a raw command on the appropriate cluster node or target_nodes. + + It will retry the command as specified by :attr:`cluster_error_retry_attempts` & + then raise an exception. + + :param args: + | Raw command args + :param kwargs: + + - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` + or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] + - Rest of the kwargs are passed to the Redis connection + + :raises RedisClusterException: if target_nodes is not provided & the command + can't be mapped to a slot + """ + command = args[0] + target_nodes_specified = target_nodes = exception = None + retry_attempts = self.cluster_error_retry_attempts + + passed_targets = kwargs.pop("target_nodes", None) + if passed_targets and not self._is_node_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + target_nodes_specified = True + retry_attempts = 1 + + for _ in range(0, retry_attempts): + if self._initialize: + await self.initialize() + try: + if not target_nodes_specified: + # Determine the nodes to execute the command on + target_nodes = await self._determine_nodes( + *args, node_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {args} command on" + ) + + if len(target_nodes) == 1: + # Return the processed result + ret = await self._execute_command(target_nodes[0], *args, **kwargs) + if command in self.result_callbacks: + return self.result_callbacks[command]( + command, {target_nodes[0].name: ret}, **kwargs + ) + return ret + else: + keys = [node.name for node in target_nodes] + values = await asyncio.gather( + *( + asyncio.ensure_future( + self._execute_command(node, *args, **kwargs) + ) + for node in target_nodes + ) + ) + if command in self.result_callbacks: + return self.result_callbacks[command]( + command, dict(zip(keys, values)), **kwargs + ) + return dict(zip(keys, values)) + except BaseException as e: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: + # The nodes and slots cache were reinitialized. + # Try again with the new cluster setup. + exception = e + else: + # All other errors should be raised. + raise e + + # If it fails the configured number of times then raise exception back + # to caller of this method + raise exception + + async def _execute_command( + self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs + ) -> Any: + redirect_addr = asking = moved = None + ttl = self.RedisClusterRequestTTL + connection_error_retry_counter = 0 + + while ttl > 0: + ttl -= 1 + try: + if asking: + target_node = self.get_node(node_name=redirect_addr) + await target_node.execute_command("ASKING") + asking = False + elif moved: + # MOVED occurred and the slots cache was updated, + # refresh the target node + slot = await self._determine_slot(*args) + target_node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and args[0] in READ_COMMANDS + ) + moved = False + + return await target_node.execute_command(*args, **kwargs) + except BusyLoadingError: + raise + except (ConnectionError, TimeoutError): + # Give the node 0.25 seconds to get back up and retry again + # with same node and configuration. After 5 attempts then try + # to reinitialize the cluster and see if the nodes + # configuration has changed or not + connection_error_retry_counter += 1 + if connection_error_retry_counter < 5: + await asyncio.sleep(0.25) + else: + # Hard force of reinitialize of the node/slots setup + # and try again with the new setup + await self.close() + raise + except MovedError as e: + # First, we will try to patch the slots/nodes cache with the + # redirected node output and try again. If MovedError exceeds + # 'reinitialize_steps' number of times, we will force + # reinitializing the tables, and then try again. + # 'reinitialize_steps' counter will increase faster when + # the same client object is shared between multiple threads. To + # reduce the frequency you can set this variable in the + # RedisCluster constructor. + self.reinitialize_counter += 1 + if ( + self.reinitialize_steps + and self.reinitialize_counter % self.reinitialize_steps == 0 + ): + await self.close() + # Reset the counter + self.reinitialize_counter = 0 + else: + self.nodes_manager._moved_exception = e + moved = True + except TryAgainError: + if ttl < self.RedisClusterRequestTTL / 2: + await asyncio.sleep(0.05) + except AskError as e: + redirect_addr = get_node_name(host=e.host, port=e.port) + asking = True + except ClusterDownError: + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + await asyncio.sleep(0.25) + await self.close() + raise + + raise ClusterError("TTL exhausted.") + + +class ClusterNode: + """ + Create a new ClusterNode. + + Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection` + objects for the (host, port). + """ + + __slots__ = ( + "_connections", + "_free", + "connection_class", + "connection_kwargs", + "host", + "max_connections", + "name", + "port", + "response_callbacks", + "server_type", + ) + + def __init__( + self, + host: str, + port: int, + server_type: Optional[str] = None, + max_connections: int = 2 ** 31, + connection_class: Type[Connection] = Connection, + response_callbacks: Dict = None, + **connection_kwargs, + ) -> None: + if host == "localhost": + host = socket.gethostbyname(host) + + connection_kwargs["host"] = host + connection_kwargs["port"] = port + self.host = host + self.port = port + self.name = get_node_name(host, port) + self.server_type = server_type + + self.max_connections = max_connections + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.response_callbacks = response_callbacks + + self._connections = [] + self._free = collections.deque(maxlen=self.max_connections) + + def __repr__(self) -> str: + return ( + f"[host={self.host}, port={self.port}, " + f"name={self.name}, server_type={self.server_type}]" + ) + + def __eq__(self, obj: "ClusterNode") -> bool: + return isinstance(obj, ClusterNode) and obj.name == self.name + + _DEL_MESSAGE = "Unclosed ClusterNode object" + + def __del__(self, _warnings=warnings): + for connection in self._connections: + if connection.is_connected: + _warnings.warn( + f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self + ) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + # TODO: Change to get_running_loop() when dropping support for py3.6 + asyncio.get_event_loop().call_exception_handler(context) + except RuntimeError: + ... + break + + async def disconnect(self) -> None: + ret = await asyncio.gather( + *( + asyncio.ensure_future(connection.disconnect()) + for connection in self._connections + ), + return_exceptions=True, + ) + exc = next((res for res in ret if isinstance(res, Exception)), None) + if exc: + raise exc + + async def execute_command(self, *args, **kwargs) -> Any: + # Acquire connection + connection = None + if self._free: + for _ in range(len(self._free)): + connection = self._free.popleft() + if connection.is_connected: + break + self._free.append(connection) + else: + connection = self._free.popleft() + else: + if len(self._connections) < self.max_connections: + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + else: + raise ConnectionError("Too many connections") + + # Execute command + command = connection.pack_command(*args) + await connection.send_packed_command(command, False) + try: + if NEVER_DECODE in kwargs: + response = await connection.read_response_without_lock( + disable_decoding=True + ) + else: + response = await connection.read_response_without_lock() + except ResponseError: + if EMPTY_RESPONSE in kwargs: + return kwargs[EMPTY_RESPONSE] + raise + finally: + # Release connection + self._free.append(connection) + + # Return response + try: + return self.response_callbacks[args[0]](response, **kwargs) + except KeyError: + return response + + +class NodesManager: + __slots__ = ( + "_moved_exception", + "_require_full_coverage", + "connection_kwargs", + "default_node", + "nodes_cache", + "read_load_balancer", + "slots_cache", + "startup_nodes", + ) + + def __init__( + self, + startup_nodes: List["ClusterNode"], + require_full_coverage: bool = False, + **kwargs, + ) -> None: + self.nodes_cache = {} + self.slots_cache = {} + self.startup_nodes = {node.name: node for node in startup_nodes} + self.default_node = None + self._require_full_coverage = require_full_coverage + self._moved_exception = None + self.connection_kwargs = kwargs + self.read_load_balancer = LoadBalancer() + + def get_node( + self, + host: Optional[str] = None, + port: Optional[int] = None, + node_name: Optional[str] = None, + ) -> "ClusterNode": + if host and port: + # the user passed host and port + if host == "localhost": + host = socket.gethostbyname(host) + return self.nodes_cache.get(get_node_name(host=host, port=port)) + elif node_name: + return self.nodes_cache.get(node_name) + else: + raise DataError( + "get_node requires one of the following: " + "1. node name " + "2. host and port" + ) + + def set_nodes( + self, + old: Dict[str, "ClusterNode"], + new: Dict[str, "ClusterNode"], + remove_old=False, + ) -> None: + tasks = [] + if remove_old: + tasks = [ + asyncio.ensure_future(node.disconnect()) + for name, node in old.items() + if name not in new + ] + for name, node in new.items(): + if name in old: + if old[name] is node: + continue + tasks.append(asyncio.ensure_future(old[name].disconnect())) + old[name] = node + + def _update_moved_slots(self) -> None: + e = self._moved_exception + redirected_node = self.get_node(host=e.host, port=e.port) + if redirected_node: + # The node already exists + if redirected_node.server_type != PRIMARY: + # Update the node's server type + redirected_node.server_type = PRIMARY + else: + # This is a new node, we will add it to the nodes cache + redirected_node = ClusterNode( + e.host, e.port, PRIMARY, **self.connection_kwargs + ) + self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node}) + if redirected_node in self.slots_cache[e.slot_id]: + # The MOVED error resulted from a failover, and the new slot owner + # had previously been a replica. + old_primary = self.slots_cache[e.slot_id][0] + # Update the old primary to be a replica and add it to the end of + # the slot's node list + old_primary.server_type = REPLICA + self.slots_cache[e.slot_id].append(old_primary) + # Remove the old replica, which is now a primary, from the slot's + # node list + self.slots_cache[e.slot_id].remove(redirected_node) + # Override the old primary with the new one + self.slots_cache[e.slot_id][0] = redirected_node + if self.default_node == old_primary: + # Update the default node with the new primary + self.default_node = redirected_node + else: + # The new slot owner is a new server, or a server from a different + # shard. We need to remove all current nodes from the slot's list + # (including replications) and add just the new node. + self.slots_cache[e.slot_id] = [redirected_node] + # Reset moved_exception + self._moved_exception = None + + def get_node_from_slot( + self, slot: int, read_from_replicas: bool = False + ) -> "ClusterNode": + if self._moved_exception: + self._update_moved_slots() + + try: + if read_from_replicas: + # get the server index in a Round-Robin manner + primary_name = self.slots_cache[slot][0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(self.slots_cache[slot]) + ) + return self.slots_cache[slot][node_idx] + return self.slots_cache[slot][0] + except (IndexError, TypeError): + raise SlotNotCoveredError( + f'Slot "{slot}" not covered by the cluster. ' + f'"require_full_coverage={self._require_full_coverage}"' + ) + + def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: + return [ + node + for node in self.nodes_cache.values() + if node.server_type == server_type + ] + + async def initialize(self) -> None: + self.read_load_balancer.reset() + tmp_nodes_cache = {} + tmp_slots = {} + disagreements = [] + startup_nodes_reachable = False + fully_covered = False + for startup_node in self.startup_nodes.values(): + try: + # Make sure cluster mode is enabled on this node + if not (await startup_node.execute_command("INFO")).get( + "cluster_enabled" + ): + raise RedisClusterException( + "Cluster mode is not enabled on this node" + ) + cluster_slots = str_if_bytes( + await startup_node.execute_command("CLUSTER SLOTS") + ) + startup_nodes_reachable = True + except (ConnectionError, TimeoutError): + continue + except ResponseError as e: + # Isn't a cluster connection, so it won't parse these + # exceptions automatically + message = e.__str__() + if "CLUSTERDOWN" in message or "MASTERDOWN" in message: + continue + else: + raise RedisClusterException( + 'ERROR sending "cluster slots" command to redis ' + f"server: {startup_node}. error: {message}" + ) + except Exception as e: + message = e.__str__() + raise RedisClusterException( + 'ERROR sending "cluster slots" command to redis ' + f"server {startup_node.name}. error: {message}" + ) + + # CLUSTER SLOTS command results in the following output: + # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] + # where each node contains the following list: [IP, port, node_id] + # Therefore, cluster_slots[0][2][0] will be the IP address of the + # primary node of the first slot section. + # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if ( + len(cluster_slots) == 1 + and not cluster_slots[0][2][0] + and len(self.startup_nodes) == 1 + ): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + for i in range(2, len(slot)): + slot[i] = [str_if_bytes(val) for val in slot[i]] + primary_node = slot[2] + host = primary_node[0] + if host == "": + host = startup_node.host + port = int(primary_node[1]) + + target_node = tmp_nodes_cache.get(get_node_name(host, port)) + if not target_node: + target_node = ClusterNode( + host, port, PRIMARY, **self.connection_kwargs + ) + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + + target_replica_node = tmp_nodes_cache.get( + get_node_name(host, port) + ) + if not target_replica_node: + target_replica_node = ClusterNode( + host, port, REPLICA, **self.connection_kwargs + ) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node + else: + # Validate that 2 nodes want to use the same slot cache + # setup + tmp_slot = tmp_slots[i][0] + if tmp_slot.name != target_node.name: + disagreements.append( + f"{tmp_slot.name} vs {target_node.name} on slot: {i}" + ) + + if len(disagreements) > 5: + raise RedisClusterException( + f"startup_nodes could not agree on a valid " + f'slots cache: {", ".join(disagreements)}' + ) + + # Validate if all slots are covered or if we should try next startup node + fully_covered = True + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + if i not in tmp_slots: + fully_covered = False + break + if fully_covered: + break + + if not startup_nodes_reachable: + raise RedisClusterException( + "Redis Cluster cannot be connected. Please provide at least " + "one reachable node. " + ) + + # Check if the slots are not fully covered + if not fully_covered and self._require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + f"All slots are not covered after query all startup_nodes. " + f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " + f"covered..." + ) + + # Set the tmp variables to the real variables + self.slots_cache = tmp_slots + self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) + # Populate the startup nodes with all discovered nodes + self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) + + # Set the default node + self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] + # If initialize was called after a MovedError, clear it + self._moved_exception = None + + async def close(self, attr: str = "nodes_cache") -> None: + self.default_node = None + await asyncio.gather( + *( + asyncio.ensure_future(node.disconnect()) + for node in getattr(self, attr).values() + ) + ) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 91961ba..9de2d46 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -24,7 +24,6 @@ from typing import ( Type, TypeVar, Union, - cast, ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse @@ -110,32 +109,32 @@ class Encoder: def encode(self, value: EncodableT) -> EncodedT: """Return a bytestring or bytes-like representation of the value""" + if isinstance(value, str): + return value.encode(self.encoding, self.encoding_errors) if isinstance(value, (bytes, memoryview)): return value - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) if isinstance(value, (int, float)): + if isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. " + "Convert to a bytes, string, int or float first." + ) return repr(value).encode() - if not isinstance(value, str): - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ # type: ignore[unreachable] - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." - ) - return value.encode(self.encoding, self.encoding_errors) + # a value we don't know how to deal with. throw an error + typename = value.__class__.__name__ + raise DataError( + f"Invalid input of type: {typename!r}. " + "Convert to a bytes, string, int or float first." + ) def decode(self, value: EncodableT, force=False) -> EncodableT: """Return a unicode string from the bytes-like representation""" if self.decode_responses or force: - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) if isinstance(value, bytes): return value.decode(self.encoding, self.encoding_errors) + if isinstance(value, memoryview): + return value.tobytes().decode(self.encoding, self.encoding_errors) return value @@ -336,7 +335,7 @@ class SocketBuffer: def close(self): try: self.purge() - self._buffer.close() # type: ignore[union-attr] + self._buffer.close() except Exception: # issue #633 suggests the purge/close somehow raised a # BadFileDescriptor error. Perhaps the client ran out of @@ -466,7 +465,7 @@ class HiredisParser(BaseParser): self._next_response = False async def can_read(self, timeout: float): - if not self._reader: + if not self._stream or not self._reader: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._next_response is False: @@ -480,14 +479,14 @@ class HiredisParser(BaseParser): timeout: Union[float, None, _Sentinel] = SENTINEL, raise_on_timeout: bool = True, ): - if self._stream is None or self._reader is None: - raise RedisError("Parser already closed.") - timeout = self._socket_timeout if timeout is SENTINEL else timeout try: - async with async_timeout.timeout(timeout): + if timeout is None: buffer = await self._stream.read(self._read_size) - if not isinstance(buffer, bytes) or len(buffer) == 0: + else: + async with async_timeout.timeout(timeout): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None self._reader.feed(buffer) # data was read from the socket and added to the buffer. @@ -516,9 +515,6 @@ class HiredisParser(BaseParser): self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - response: Union[ - EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] - ] # _next_response might be cached from a can_read() call if self._next_response is not False: response = self._next_response @@ -541,8 +537,7 @@ class HiredisParser(BaseParser): and isinstance(response[0], ConnectionError) ): raise response[0] - # cast as there won't be a ConnectionError here. - return cast(Union[EncodableT, List[EncodableT]], response) + return response DefaultParser: Type[Union[PythonParser, HiredisParser]] @@ -637,7 +632,7 @@ class Connection: self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout if retry_on_timeout: - if retry is None: + if not retry: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable @@ -681,7 +676,7 @@ class Connection: @property def is_connected(self): - return bool(self._reader and self._writer) + return self._reader and self._writer def register_connect_callback(self, callback): self._connect_callbacks.append(weakref.WeakMethod(callback)) @@ -713,7 +708,7 @@ class Connection: raise ConnectionError(exc) from exc try: - if self.redis_connect_func is None: + if not self.redis_connect_func: # Use the default on_connect function await self.on_connect() else: @@ -745,7 +740,7 @@ class Connection: self._reader = reader self._writer = writer sock = writer.transport.get_extra_info("socket") - if sock is not None: + if sock: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: # TCP_KEEPALIVE @@ -856,32 +851,29 @@ class Connection: await self.retry.call_with_retry(self._send_ping, self._ping_failed) async def _send_packed_command(self, command: Iterable[bytes]) -> None: - if self._writer is None: - raise RedisError("Connection already closed.") - self._writer.writelines(command) await self._writer.drain() async def send_packed_command( - self, - command: Union[bytes, str, Iterable[bytes]], - check_health: bool = True, - ): - """Send an already packed command to the Redis server""" - if not self._writer: + self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True + ) -> None: + if not self.is_connected: await self.connect() - # guard against health check recursion - if check_health: + elif check_health: await self.check_health() + try: if isinstance(command, str): command = command.encode() if isinstance(command, bytes): command = [command] - await asyncio.wait_for( - self._send_packed_command(command), - self.socket_timeout, - ) + if self.socket_timeout: + await asyncio.wait_for( + self._send_packed_command(command), self.socket_timeout + ) + else: + self._writer.writelines(command) + await self._writer.drain() except asyncio.TimeoutError: await self.disconnect() raise TimeoutError("Timeout writing to socket") from None @@ -901,8 +893,6 @@ class Connection: async def send_command(self, *args, **kwargs): """Pack and send a command to the Redis server""" - if not self.is_connected: - await self.connect() await self.send_packed_command( self.pack_command(*args), check_health=kwargs.get("check_health", True) ) @@ -923,10 +913,50 @@ class Connection: """Read the response from a previously sent command""" try: async with self._lock: + if self.socket_timeout: + async with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + else: + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + except OSError as e: + await self.disconnect() + raise ConnectionError( + f"Error while reading from {self.host}:{self.port} : {e.args}" + ) + except BaseException: + await self.disconnect() + raise + + if self.health_check_interval: + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + self.next_health_check = func().time() + self.health_check_interval + + if isinstance(response, ResponseError): + raise response from None + return response + + async def read_response_without_lock(self, disable_decoding: bool = False): + """Read the response from a previously sent command""" + try: + if self.socket_timeout: async with async_timeout.timeout(self.socket_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) + else: + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) except asyncio.TimeoutError: await self.disconnect() raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") @@ -1182,10 +1212,7 @@ class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] self._lock = asyncio.Lock() def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: - pieces = [ - ("path", self.path), - ("db", self.db), - ] + pieces = [("path", self.path), ("db", self.db)] if self.client_name: pieces.append(("client_name", self.client_name)) return pieces @@ -1254,12 +1281,11 @@ def parse_url(url: str) -> ConnectKwargs: parser = URL_QUERY_ARGUMENT_PARSERS.get(name) if parser: try: - # We can't type this. - kwargs[name] = parser(value) # type: ignore[misc] + kwargs[name] = parser(value) except (TypeError, ValueError): raise ValueError(f"Invalid value for `{name}` in connection URL.") else: - kwargs[name] = value # type: ignore[misc] + kwargs[name] = value if parsed.username: kwargs["username"] = unquote(parsed.username) diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py new file mode 100644 index 0000000..273fe03 --- /dev/null +++ b/redis/asyncio/parser.py @@ -0,0 +1,95 @@ +from typing import TYPE_CHECKING, List, Optional, Union + +from redis.exceptions import RedisError, ResponseError + +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + + +class CommandsParser: + """ + Parses Redis commands to get command keys. + + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with 'movablekeys', + and these commands' keys are determined by the command 'COMMAND GETKEYS'. + + NOTE: Due to a bug in redis<7.0, this does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this with EVAL or EVALSHA. + """ + + __slots__ = ("commands",) + + def __init__(self) -> None: + self.commands = {} + + async def initialize(self, r: "ClusterNode") -> None: + commands = await r.execute_command("COMMAND") + for cmd, command in commands.items(): + if "movablekeys" in command["flags"]: + commands[cmd] = -1 + elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: + commands[cmd] = 0 + elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: + commands[cmd] = 1 + self.commands = {cmd.upper(): command for cmd, command in commands.items()} + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + async def get_keys( + self, redis_conn: "ClusterNode", *args + ) -> Optional[Union[List[str], List[bytes]]]: + if len(args) < 2: + # The command has no keys in it + return None + + try: + command = self.commands[args[0]] + except KeyError: + # try to split the command name and to take only the main command + # e.g. 'memory' for 'memory usage' + args = args[0].split() + list(args[1:]) + cmd_name = args[0] + if cmd_name not in self.commands: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + await self.initialize(redis_conn) + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name.upper()} command doesn't exist in Redis commands" + ) + + command = self.commands[cmd_name] + + if command == 1: + return [args[1]] + if command == 0: + return None + if command == -1: + return await self._get_moveable_keys(redis_conn, *args) + + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) + last_key_pos + return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] + + async def _get_moveable_keys( + self, redis_conn: "ClusterNode", *args + ) -> Optional[List[str]]: + try: + keys = await redis_conn.execute_command("COMMAND GETKEYS", *args) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e + return keys diff --git a/redis/client.py b/redis/client.py index 87c7991..7c83b61 100755 --- a/redis/client.py +++ b/redis/client.py @@ -410,11 +410,7 @@ def parse_slowlog_get(response, **options): space = " " if options.get("decode_responses", False) else b" " def parse_item(item): - result = { - "id": item[0], - "start_time": int(item[1]), - "duration": int(item[2]), - } + result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} # Redis Enterprise injects another entry at index [3], which has # the complexity info (i.e. the value N in case the command has # an O(N) complexity) instead of the command. @@ -703,7 +699,7 @@ class AbstractRedis: **string_keys_to_dict("SORT", sort_return_tuples), **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), **string_keys_to_dict( - "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READONLY READWRITE " "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", bool_ok, ), @@ -753,17 +749,18 @@ class AbstractRedis: "CLUSTER DELSLOTSRANGE": bool_ok, "CLUSTER FAILOVER": bool_ok, "CLUSTER FORGET": bool_ok, + "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), "CLUSTER INFO": parse_cluster_info, "CLUSTER KEYSLOT": lambda x: int(x), "CLUSTER MEET": bool_ok, "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICAS": parse_cluster_nodes, "CLUSTER REPLICATE": bool_ok, "CLUSTER RESET": bool_ok, "CLUSTER SAVECONFIG": bool_ok, "CLUSTER SET-CONFIG-EPOCH": bool_ok, "CLUSTER SETSLOT": bool_ok, "CLUSTER SLAVES": parse_cluster_nodes, - "CLUSTER REPLICAS": parse_cluster_nodes, "COMMAND": parse_command, "COMMAND COUNT": int, "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), @@ -1035,11 +1032,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): """Set a custom Response Callback""" self.response_callbacks[command] = callback - def load_external_module( - self, - funcname, - func, - ): + def load_external_module(self, funcname, func): """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. diff --git a/redis/cluster.py b/redis/cluster.py index bf7ac20..d42d49b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -128,10 +128,7 @@ REDIS_ALLOWED_KEYS = ( "unix_socket_path", "username", ) -KWARGS_DISABLED_KEYS = ( - "host", - "port", -) +KWARGS_DISABLED_KEYS = ("host", "port") # Not complete, but covers the major ones # https://redis.io/commands @@ -207,7 +204,7 @@ class ClusterParser(DefaultParser): ) -class RedisCluster(RedisClusterCommands): +class AbstractRedisCluster: RedisClusterRequestTTL = 16 PRIMARIES = "primaries" @@ -308,10 +305,7 @@ class RedisCluster(RedisClusterCommands): ], PRIMARIES, ), - list_keys_to_dict( - ["FUNCTION DUMP"], - RANDOM, - ), + list_keys_to_dict(["FUNCTION DUMP"], RANDOM), list_keys_to_dict( [ "CLUSTER COUNTKEYSINSLOT", @@ -360,49 +354,14 @@ class RedisCluster(RedisClusterCommands): ], ) - CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { - "CLUSTER ADDSLOTS": bool, - "CLUSTER ADDSLOTSRANGE": bool, - "CLUSTER COUNT-FAILURE-REPORTS": int, - "CLUSTER COUNTKEYSINSLOT": int, - "CLUSTER DELSLOTS": bool, - "CLUSTER DELSLOTSRANGE": bool, - "CLUSTER FAILOVER": bool, - "CLUSTER FORGET": bool, - "CLUSTER GETKEYSINSLOT": list, - "CLUSTER KEYSLOT": int, - "CLUSTER MEET": bool, - "CLUSTER REPLICATE": bool, - "CLUSTER RESET": bool, - "CLUSTER SAVECONFIG": bool, - "CLUSTER SET-CONFIG-EPOCH": bool, - "CLUSTER SETSLOT": bool, - "CLUSTER SLOTS": parse_cluster_slots, - "ASKING": bool, - "READONLY": bool, - "READWRITE": bool, - } + CLUSTER_COMMANDS_RESPONSE_CALLBACKS = {"CLUSTER SLOTS": parse_cluster_slots} RESULT_CALLBACKS = dict_merge( + list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), list_keys_to_dict( - [ - "PUBSUB NUMSUB", - ], - parse_pubsub_numsub, - ), - list_keys_to_dict( - [ - "PUBSUB NUMPAT", - ], - lambda command, res: sum(list(res.values())), - ), - list_keys_to_dict( - [ - "KEYS", - "PUBSUB CHANNELS", - ], - merge_result, + ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) ), + list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result), list_keys_to_dict( [ "PING", @@ -420,49 +379,69 @@ class RedisCluster(RedisClusterCommands): lambda command, res: all(res.values()) if isinstance(res, dict) else res, ), list_keys_to_dict( - [ - "DBSIZE", - "WAIT", - ], + ["DBSIZE", "WAIT"], lambda command, res: sum(res.values()) if isinstance(res, dict) else res, ), list_keys_to_dict( - [ - "CLIENT UNBLOCK", - ], - lambda command, res: 1 if sum(res.values()) > 0 else 0, - ), - list_keys_to_dict( - [ - "SCAN", - ], - parse_scan_result, - ), - list_keys_to_dict( - [ - "SCRIPT LOAD", - ], - lambda command, res: list(res.values()).pop(), + ["CLIENT UNBLOCK"], lambda command, res: 1 if sum(res.values()) > 0 else 0 ), + list_keys_to_dict(["SCAN"], parse_scan_result), list_keys_to_dict( - [ - "SCRIPT EXISTS", - ], - lambda command, res: [all(k) for k in zip(*res.values())], + ["SCRIPT LOAD"], lambda command, res: list(res.values()).pop() ), list_keys_to_dict( - [ - "SCRIPT FLUSH", - ], - lambda command, res: all(res.values()), + ["SCRIPT EXISTS"], lambda command, res: [all(k) for k in zip(*res.values())] ), + list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())), ) - ERRORS_ALLOW_RETRY = ( - ConnectionError, - TimeoutError, - ClusterDownError, - ) + ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + + +class RedisCluster(AbstractRedisCluster, RedisClusterCommands): + @classmethod + def from_url(cls, url, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/redis> + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/rediss> + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + return cls(url=url, **kwargs) def __init__( self, @@ -617,50 +596,6 @@ class RedisCluster(RedisClusterCommands): # Client was already disconnected. do nothing pass - @classmethod - def from_url(cls, url, **kwargs): - """ - Return a Redis client object configured from the given URL - - For example:: - - redis://[[username]:[password]]@localhost:6379/0 - rediss://[[username]:[password]]@localhost:6379/0 - unix://[[username]:[password]]@/path/to/socket.sock?db=0 - - Three URL schemes are supported: - - - `redis://` creates a TCP socket connection. See more at: - <https://www.iana.org/assignments/uri-schemes/prov/redis> - - `rediss://` creates a SSL wrapped TCP socket connection. See more at: - <https://www.iana.org/assignments/uri-schemes/prov/rediss> - - ``unix://``: creates a Unix Domain Socket connection. - - The username, password, hostname, path and all querystring values - are passed through urllib.parse.unquote in order to replace any - percent-encoded values with their corresponding characters. - - There are several ways to specify a database number. The first value - found will be used: - - 1. A ``db`` querystring option, e.g. redis://localhost?db=0 - 2. If using the redis:// or rediss:// schemes, the path argument - of the url, e.g. redis://localhost/0 - 3. A ``db`` keyword argument to this function. - - If none of these options are specified, the default db=0 is used. - - All querystring options are cast to their appropriate Python types. - Boolean arguments can be specified with string values "True"/"False" - or "Yes"/"No". Values that cannot be properly cast cause a - ``ValueError`` to be raised. Once parsed, the querystring arguments - and keyword arguments are passed to the ``ConnectionPool``'s - class initializer. In the case of conflicting arguments, querystring - arguments always win. - - """ - return cls(url=url, **kwargs) - def on_connect(self, connection): """ Initialize the connection, authenticate and select a database and send @@ -996,9 +931,6 @@ class RedisCluster(RedisClusterCommands): return slots.pop() - def reinitialize_caches(self): - self.nodes_manager.initialize() - def get_encoder(self): """ Get the connections' encoder @@ -1085,7 +1017,7 @@ class RedisCluster(RedisClusterCommands): # Return the processed result return self._process_result(args[0], res, **kwargs) except BaseException as e: - if type(e) in RedisCluster.ERRORS_ALLOW_RETRY: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. exception = e @@ -1246,11 +1178,7 @@ class RedisCluster(RedisClusterCommands): else: return res - def load_external_module( - self, - funcname, - func, - ): + def load_external_module(self, funcname, func): """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. @@ -1464,9 +1392,7 @@ class NodesManager: for node in nodes: if node.redis_connection is None: node.redis_connection = self.create_redis_node( - host=node.host, - port=node.port, - **self.connection_kwargs, + host=node.host, port=node.port, **self.connection_kwargs ) def create_redis_node(self, host, port, **kwargs): diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index de21a9e..e3383ff 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,4 +1,4 @@ -from .cluster import RedisClusterCommands +from .cluster import AsyncRedisClusterCommands, RedisClusterCommands from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args from .parser import CommandsParser @@ -6,6 +6,7 @@ from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands from .sentinel import AsyncSentinelCommands, SentinelCommands __all__ = [ + "AsyncRedisClusterCommands", "RedisClusterCommands", "CommandsParser", "AsyncCoreCommands", diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index aaddb6a..06b702f 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -1,27 +1,57 @@ -from typing import Iterator, Union +import asyncio +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterable, + Iterator, + List, + Mapping, + NoReturn, + Optional, + Union, +) +from redis.compat import Literal from redis.crc import key_slot from redis.exceptions import RedisClusterException, RedisError -from redis.typing import PatternT +from redis.typing import ( + AnyKeyT, + ClusterCommandsProtocol, + EncodableT, + KeysT, + KeyT, + PatternT, +) from .core import ( ACLCommands, + AsyncACLCommands, + AsyncDataAccessCommands, + AsyncFunctionCommands, + AsyncManagementCommands, + AsyncScriptCommands, DataAccessCommands, FunctionCommands, ManagementCommands, PubSubCommands, + ResponseT, ScriptCommands, ) from .helpers import list_or_args from .redismodules import RedisModuleCommands +if TYPE_CHECKING: + from redis.asyncio.cluster import TargetNodesT + -class ClusterMultiKeyCommands: +class ClusterMultiKeyCommands(ClusterCommandsProtocol): """ A class containing commands that handle more than one key """ - def _partition_keys_by_slot(self, keys): + def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]: """ Split keys into a dictionary that maps a slot to a list of keys. @@ -34,7 +64,7 @@ class ClusterMultiKeyCommands: return slots_to_keys - def mget_nonatomic(self, keys, *args): + def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: """ Splits the keys into different slots and then calls MGET for the keys of every slot. This operation will not be atomic @@ -70,7 +100,7 @@ class ClusterMultiKeyCommands: vals_in_order = [all_results[key] for key in keys] return vals_in_order - def mset_nonatomic(self, mapping): + def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -99,7 +129,7 @@ class ClusterMultiKeyCommands: return res - def _split_command_across_slots(self, command, *keys): + def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: """ Runs the given command once for the keys of each slot. Returns the sum of the return values. @@ -114,7 +144,7 @@ class ClusterMultiKeyCommands: return total - def exists(self, *keys): + def exists(self, *keys: KeyT) -> ResponseT: """ Returns the number of ``names`` that exist in the whole cluster. The keys are first split up into slots @@ -124,7 +154,7 @@ class ClusterMultiKeyCommands: """ return self._split_command_across_slots("EXISTS", *keys) - def delete(self, *keys): + def delete(self, *keys: KeyT) -> ResponseT: """ Deletes the given keys in the cluster. The keys are first split up into slots @@ -137,7 +167,7 @@ class ClusterMultiKeyCommands: """ return self._split_command_across_slots("DEL", *keys) - def touch(self, *keys): + def touch(self, *keys: KeyT) -> ResponseT: """ Updates the last access time of given keys across the cluster. @@ -152,7 +182,7 @@ class ClusterMultiKeyCommands: """ return self._split_command_across_slots("TOUCH", *keys) - def unlink(self, *keys): + def unlink(self, *keys: KeyT) -> ResponseT: """ Remove the specified keys in a different thread. @@ -167,160 +197,135 @@ class ClusterMultiKeyCommands: return self._split_command_across_slots("UNLINK", *keys) -class ClusterManagementCommands(ManagementCommands): +class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands): """ - A class for Redis Cluster management commands - - The class inherits from Redis's core ManagementCommands class and do the - required adjustments to work with cluster mode + A class containing commands that handle more than one key """ - def slaveof(self, *args, **kwargs): + async def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: """ - Make the server a replica of another instance, or promote it as master. + Splits the keys into different slots and then calls MGET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. - For more information see https://redis.io/commands/slaveof - """ - raise RedisClusterException("SLAVEOF is not supported in cluster mode") + Returns a list of values ordered identically to ``keys`` - def replicaof(self, *args, **kwargs): + For more information see https://redis.io/commands/mget """ - Make the server a replica of another instance, or promote it as master. - For more information see https://redis.io/commands/replicaof - """ - raise RedisClusterException("REPLICAOF is not supported in cluster mode") + from redis.client import EMPTY_RESPONSE - def swapdb(self, *args, **kwargs): - """ - Swaps two Redis databases. + options = {} + if not args: + options[EMPTY_RESPONSE] = [] - For more information see https://redis.io/commands/swapdb - """ - raise RedisClusterException("SWAPDB is not supported in cluster mode") + # Concatenate all keys into a list + keys = list_or_args(keys, args) + # Split keys into slots + slots_to_keys = self._partition_keys_by_slot(keys) + # Call MGET for every slot and concatenate + # the results + # We must make sure that the keys are returned in order + all_values = await asyncio.gather( + *( + asyncio.ensure_future( + self.execute_command("MGET", *slot_keys, **options) + ) + for slot_keys in slots_to_keys.values() + ) + ) -class ClusterDataAccessCommands(DataAccessCommands): - """ - A class for Redis Cluster Data Access Commands + all_results = {} + for slot_keys, slot_values in zip(slots_to_keys.values(), all_values): + all_results.update(dict(zip(slot_keys, slot_values))) - The class inherits from Redis's core DataAccessCommand class and do the - required adjustments to work with cluster mode - """ + # Sort the results + vals_in_order = [all_results[key] for key in keys] + return vals_in_order - def stralgo( - self, - algo, - value1, - value2, - specific_argument="strings", - len=False, - idx=False, - minmatchlen=None, - withmatchlen=False, - **kwargs, - ): + async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: """ - Implements complex algorithms that operate on strings. - Right now the only algorithm implemented is the LCS algorithm - (longest common substring). However new algorithms could be - implemented in the future. + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). - ``algo`` Right now must be LCS - ``value1`` and ``value2`` Can be two strings or two keys - ``specific_argument`` Specifying if the arguments to the algorithm - will be keys or strings. strings is the default. - ``len`` Returns just the len of the match. - ``idx`` Returns the match positions in each string. - ``minmatchlen`` Restrict the list of matches to the ones of a given - minimal length. Can be provided only when ``idx`` set to True. - ``withmatchlen`` Returns the matches with the len of the match. - Can be provided only when ``idx`` set to True. + Splits the keys into different slots and then calls MSET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. - For more information see https://redis.io/commands/stralgo + For more information see https://redis.io/commands/mset """ - target_nodes = kwargs.pop("target_nodes", None) - if specific_argument == "strings" and target_nodes is None: - target_nodes = "default-node" - kwargs.update({"target_nodes": target_nodes}) - return super().stralgo( - algo, - value1, - value2, - specific_argument, - len, - idx, - minmatchlen, - withmatchlen, - **kwargs, - ) - def scan_iter( - self, - match: Union[PatternT, None] = None, - count: Union[int, None] = None, - _type: Union[str, None] = None, - **kwargs, - ) -> Iterator: - # Do the first query with cursor=0 for all nodes - cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs) - yield from data + # Partition the keys by slot + slots_to_pairs = {} + for pair in mapping.items(): + # encode the key + k = self.encoder.encode(pair[0]) + slot = key_slot(k) + slots_to_pairs.setdefault(slot, []).extend(pair) - cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} - if cursors: - # Get nodes by name - nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + # Call MSET for every slot and concatenate + # the results (one result per slot) + return await asyncio.gather( + *( + asyncio.ensure_future(self.execute_command("MSET", *pairs)) + for pairs in slots_to_pairs.values() + ) + ) - # Iterate over each node till its cursor is 0 - kwargs.pop("target_nodes", None) - while cursors: - for name, cursor in cursors.items(): - cur, data = self.scan( - cursor=cursor, - match=match, - count=count, - _type=_type, - target_nodes=nodes[name], - **kwargs, - ) - yield from data - cursors[name] = cur[name] + async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: + """ + Runs the given command once for the keys + of each slot. Returns the sum of the return values. + """ + # Partition the keys by slot + slots_to_keys = self._partition_keys_by_slot(keys) - cursors = { - name: cursor for name, cursor in cursors.items() if cursor != 0 - } + # Sum up the reply from each command + return sum( + await asyncio.gather( + *( + asyncio.ensure_future(self.execute_command(command, *slot_keys)) + for slot_keys in slots_to_keys.values() + ) + ) + ) -class RedisClusterCommands( - ClusterMultiKeyCommands, - ClusterManagementCommands, - ACLCommands, - PubSubCommands, - ClusterDataAccessCommands, - ScriptCommands, - FunctionCommands, - RedisModuleCommands, -): +class ClusterManagementCommands(ManagementCommands): """ - A class for all Redis Cluster commands + A class for Redis Cluster management commands - For key-based commands, the target node(s) will be internally determined - by the keys' hash slot. - Non-key-based commands can be executed with the 'target_nodes' argument to - target specific nodes. By default, if target_nodes is not specified, the - command will be executed on the default cluster node. + The class inherits from Redis's core ManagementCommands class and do the + required adjustments to work with cluster mode + """ - :param :target_nodes: type can be one of the followings: - - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM - - 'ClusterNode' - - 'list(ClusterNodes)' - - 'dict(any:clusterNodes)' + def slaveof(self, *args, **kwargs) -> NoReturn: + """ + Make the server a replica of another instance, or promote it as master. - for example: - r.cluster_info(target_nodes=RedisCluster.ALL_NODES) - """ + For more information see https://redis.io/commands/slaveof + """ + raise RedisClusterException("SLAVEOF is not supported in cluster mode") - def cluster_myid(self, target_node): + def replicaof(self, *args, **kwargs) -> NoReturn: + """ + Make the server a replica of another instance, or promote it as master. + + For more information see https://redis.io/commands/replicaof + """ + raise RedisClusterException("REPLICAOF is not supported in cluster mode") + + def swapdb(self, *args, **kwargs) -> NoReturn: + """ + Swaps two Redis databases. + + For more information see https://redis.io/commands/swapdb + """ + raise RedisClusterException("SWAPDB is not supported in cluster mode") + + def cluster_myid(self, target_node: "TargetNodesT") -> ResponseT: """ Returns the node's id. @@ -331,7 +336,9 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER MYID", target_nodes=target_node) - def cluster_addslots(self, target_node, *slots): + def cluster_addslots( + self, target_node: "TargetNodesT", *slots: EncodableT + ) -> ResponseT: """ Assign new hash slots to receiving node. Sends to specified node. @@ -344,7 +351,9 @@ class RedisClusterCommands( "CLUSTER ADDSLOTS", *slots, target_nodes=target_node ) - def cluster_addslotsrange(self, target_node, *slots): + def cluster_addslotsrange( + self, target_node: "TargetNodesT", *slots: EncodableT + ) -> ResponseT: """ Similar to the CLUSTER ADDSLOTS command. The difference between the two commands is that ADDSLOTS takes a list of slots @@ -360,7 +369,7 @@ class RedisClusterCommands( "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node ) - def cluster_countkeysinslot(self, slot_id): + def cluster_countkeysinslot(self, slot_id: int) -> ResponseT: """ Return the number of local keys in the specified hash slot Send to node based on specified slot_id @@ -369,7 +378,7 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) - def cluster_count_failure_report(self, node_id): + def cluster_count_failure_report(self, node_id: str) -> ResponseT: """ Return the number of failure reports active for a given node Sends to a random node @@ -378,7 +387,7 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) - def cluster_delslots(self, *slots): + def cluster_delslots(self, *slots: EncodableT) -> List[bool]: """ Set hash slots as unbound in the cluster. It determines by it self what node the slot is in and sends it there @@ -389,7 +398,7 @@ class RedisClusterCommands( """ return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] - def cluster_delslotsrange(self, *slots): + def cluster_delslotsrange(self, *slots: EncodableT) -> ResponseT: """ Similar to the CLUSTER DELSLOTS command. The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove @@ -400,7 +409,9 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER DELSLOTSRANGE", *slots) - def cluster_failover(self, target_node, option=None): + def cluster_failover( + self, target_node: "TargetNodesT", option: Optional[str] = None + ) -> ResponseT: """ Forces a slave to perform a manual failover of its master Sends to specified node @@ -422,7 +433,7 @@ class RedisClusterCommands( else: return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) - def cluster_info(self, target_nodes=None): + def cluster_info(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: """ Provides info about Redis Cluster node state. The command will be sent to a random node in the cluster if no target @@ -432,7 +443,7 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) - def cluster_keyslot(self, key): + def cluster_keyslot(self, key: str) -> ResponseT: """ Returns the hash slot of the specified key Sends to random node in the cluster @@ -441,7 +452,9 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER KEYSLOT", key) - def cluster_meet(self, host, port, target_nodes=None): + def cluster_meet( + self, host: str, port: int, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: """ Force a node cluster to handshake with another node. Sends to specified node. @@ -452,7 +465,7 @@ class RedisClusterCommands( "CLUSTER MEET", host, port, target_nodes=target_nodes ) - def cluster_nodes(self): + def cluster_nodes(self) -> ResponseT: """ Get Cluster config for the node. Sends to random node in the cluster @@ -461,7 +474,9 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER NODES") - def cluster_replicate(self, target_nodes, node_id): + def cluster_replicate( + self, target_nodes: "TargetNodesT", node_id: str + ) -> ResponseT: """ Reconfigure a node as a slave of the specified master node @@ -471,7 +486,9 @@ class RedisClusterCommands( "CLUSTER REPLICATE", node_id, target_nodes=target_nodes ) - def cluster_reset(self, soft=True, target_nodes=None): + def cluster_reset( + self, soft: bool = True, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: """ Reset a Redis Cluster node @@ -484,7 +501,9 @@ class RedisClusterCommands( "CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes ) - def cluster_save_config(self, target_nodes=None): + def cluster_save_config( + self, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: """ Forces the node to save cluster state on disk @@ -492,7 +511,7 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes) - def cluster_get_keys_in_slot(self, slot, num_keys): + def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> ResponseT: """ Returns the number of keys in the specified cluster slot @@ -500,7 +519,9 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys) - def cluster_set_config_epoch(self, epoch, target_nodes=None): + def cluster_set_config_epoch( + self, epoch: int, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: """ Set the configuration epoch in a new node @@ -510,7 +531,9 @@ class RedisClusterCommands( "CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes ) - def cluster_setslot(self, target_node, node_id, slot_id, state): + def cluster_setslot( + self, target_node: "TargetNodesT", node_id: str, slot_id: int, state: str + ) -> ResponseT: """ Bind an hash slot to a specific node @@ -528,7 +551,7 @@ class RedisClusterCommands( else: raise RedisError(f"Invalid slot state: {state}") - def cluster_setslot_stable(self, slot_id): + def cluster_setslot_stable(self, slot_id: int) -> ResponseT: """ Clears migrating / importing state from the slot. It determines by it self what node the slot is in and sends it there. @@ -537,7 +560,9 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE") - def cluster_replicas(self, node_id, target_nodes=None): + def cluster_replicas( + self, node_id: str, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: """ Provides a list of replica nodes replicating from the specified primary target node. @@ -548,7 +573,7 @@ class RedisClusterCommands( "CLUSTER REPLICAS", node_id, target_nodes=target_nodes ) - def cluster_slots(self, target_nodes=None): + def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: """ Get array of Cluster slot to node mappings @@ -556,7 +581,7 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) - def cluster_links(self, target_node): + def cluster_links(self, target_node: "TargetNodesT") -> ResponseT: """ Each node in a Redis Cluster maintains a pair of long-lived TCP link with each peer in the cluster: One for sending outbound messages towards the peer and one @@ -568,7 +593,7 @@ class RedisClusterCommands( """ return self.execute_command("CLUSTER LINKS", target_nodes=target_node) - def readonly(self, target_nodes=None): + def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: """ Enables read queries. The command will be sent to the default cluster node if target_nodes is @@ -582,7 +607,7 @@ class RedisClusterCommands( self.read_from_replicas = True return self.execute_command("READONLY", target_nodes=target_nodes) - def readwrite(self, target_nodes=None): + def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: """ Disables read queries. The command will be sent to the default cluster node if target_nodes is @@ -593,3 +618,227 @@ class RedisClusterCommands( # Reset read from replicas flag self.read_from_replicas = False return self.execute_command("READWRITE", target_nodes=target_nodes) + + +class AsyncClusterManagementCommands( + ClusterManagementCommands, AsyncManagementCommands +): + """ + A class for Redis Cluster management commands + + The class inherits from Redis's core ManagementCommands class and do the + required adjustments to work with cluster mode + """ + + async def cluster_delslots(self, *slots: EncodableT) -> List[bool]: + """ + Set hash slots as unbound in the cluster. + It determines by it self what node the slot is in and sends it there + + Returns a list of the results for each processed slot. + + For more information see https://redis.io/commands/cluster-delslots + """ + return await asyncio.gather( + *( + asyncio.ensure_future(self.execute_command("CLUSTER DELSLOTS", slot)) + for slot in slots + ) + ) + + +class ClusterDataAccessCommands(DataAccessCommands): + """ + A class for Redis Cluster Data Access Commands + + The class inherits from Redis's core DataAccessCommand class and do the + required adjustments to work with cluster mode + """ + + def stralgo( + self, + algo: Literal["LCS"], + value1: KeyT, + value2: KeyT, + specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", + len: bool = False, + idx: bool = False, + minmatchlen: Optional[int] = None, + withmatchlen: bool = False, + **kwargs, + ) -> ResponseT: + """ + Implements complex algorithms that operate on strings. + Right now the only algorithm implemented is the LCS algorithm + (longest common substring). However new algorithms could be + implemented in the future. + + ``algo`` Right now must be LCS + ``value1`` and ``value2`` Can be two strings or two keys + ``specific_argument`` Specifying if the arguments to the algorithm + will be keys or strings. strings is the default. + ``len`` Returns just the len of the match. + ``idx`` Returns the match positions in each string. + ``minmatchlen`` Restrict the list of matches to the ones of a given + minimal length. Can be provided only when ``idx`` set to True. + ``withmatchlen`` Returns the matches with the len of the match. + Can be provided only when ``idx`` set to True. + + For more information see https://redis.io/commands/stralgo + """ + target_nodes = kwargs.pop("target_nodes", None) + if specific_argument == "strings" and target_nodes is None: + target_nodes = "default-node" + kwargs.update({"target_nodes": target_nodes}) + return super().stralgo( + algo, + value1, + value2, + specific_argument, + len, + idx, + minmatchlen, + withmatchlen, + **kwargs, + ) + + def scan_iter( + self, + match: Optional[PatternT] = None, + count: Optional[int] = None, + _type: Optional[str] = None, + **kwargs, + ) -> Iterator: + # Do the first query with cursor=0 for all nodes + cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs) + yield from data + + cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} + if cursors: + # Get nodes by name + nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + + # Iterate over each node till its cursor is 0 + kwargs.pop("target_nodes", None) + while cursors: + for name, cursor in cursors.items(): + cur, data = self.scan( + cursor=cursor, + match=match, + count=count, + _type=_type, + target_nodes=nodes[name], + **kwargs, + ) + yield from data + cursors[name] = cur[name] + + cursors = { + name: cursor for name, cursor in cursors.items() if cursor != 0 + } + + +class AsyncClusterDataAccessCommands( + ClusterDataAccessCommands, AsyncDataAccessCommands +): + """ + A class for Redis Cluster Data Access Commands + + The class inherits from Redis's core DataAccessCommand class and do the + required adjustments to work with cluster mode + """ + + async def scan_iter( + self, + match: Optional[PatternT] = None, + count: Optional[int] = None, + _type: Optional[str] = None, + **kwargs, + ) -> AsyncIterator: + # Do the first query with cursor=0 for all nodes + cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs) + for value in data: + yield value + + cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} + if cursors: + # Get nodes by name + nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + + # Iterate over each node till its cursor is 0 + kwargs.pop("target_nodes", None) + while cursors: + for name, cursor in cursors.items(): + cur, data = await self.scan( + cursor=cursor, + match=match, + count=count, + _type=_type, + target_nodes=nodes[name], + **kwargs, + ) + for value in data: + yield value + cursors[name] = cur[name] + + cursors = { + name: cursor for name, cursor in cursors.items() if cursor != 0 + } + + +class RedisClusterCommands( + ClusterMultiKeyCommands, + ClusterManagementCommands, + ACLCommands, + PubSubCommands, + ClusterDataAccessCommands, + ScriptCommands, + FunctionCommands, + RedisModuleCommands, +): + """ + A class for all Redis Cluster commands + + For key-based commands, the target node(s) will be internally determined + by the keys' hash slot. + Non-key-based commands can be executed with the 'target_nodes' argument to + target specific nodes. By default, if target_nodes is not specified, the + command will be executed on the default cluster node. + + :param :target_nodes: type can be one of the followings: + - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + r.cluster_info(target_nodes=RedisCluster.ALL_NODES) + """ + + +class AsyncRedisClusterCommands( + AsyncClusterMultiKeyCommands, + AsyncClusterManagementCommands, + AsyncACLCommands, + AsyncClusterDataAccessCommands, + AsyncScriptCommands, + AsyncFunctionCommands, +): + """ + A class for all Redis Cluster commands + + For key-based commands, the target node(s) will be internally determined + by the keys' hash slot. + Non-key-based commands can be executed with the 'target_nodes' argument to + target specific nodes. By default, if target_nodes is not specified, the + command will be executed on the default cluster node. + + :param :target_nodes: type can be one of the followings: + - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + r.cluster_info(target_nodes=RedisCluster.ALL_NODES) + """ diff --git a/redis/commands/parser.py b/redis/commands/parser.py index 89292ab..936f2ec 100644 --- a/redis/commands/parser.py +++ b/redis/commands/parser.py @@ -12,7 +12,6 @@ class CommandsParser: """ def __init__(self, redis_connection): - self.initialized = False self.commands = {} self.initialize(redis_connection) diff --git a/redis/crc.py b/redis/crc.py index c47e2ac..e261241 100644 --- a/redis/crc.py +++ b/redis/crc.py @@ -1,5 +1,7 @@ from binascii import crc_hqx +from redis.typing import EncodedT + # Redis Cluster's key space is divided into 16384 slots. # For more information see: https://github.com/redis/redis/issues/2576 REDIS_CLUSTER_HASH_SLOTS = 16384 @@ -7,7 +9,7 @@ REDIS_CLUSTER_HASH_SLOTS = 16384 __all__ = ["key_slot", "REDIS_CLUSTER_HASH_SLOTS"] -def key_slot(key, bucket=REDIS_CLUSTER_HASH_SLOTS): +def key_slot(key: EncodedT, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int: """Calculate key slot for a given key. See Keys distribution model in https://redis.io/topics/cluster-spec :param key - bytes diff --git a/redis/typing.py b/redis/typing.py index 73ae411..6748612 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,13 +1,14 @@ # from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Iterable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union from redis.compat import Protocol if TYPE_CHECKING: from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool - from redis.connection import ConnectionPool + from redis.asyncio.connection import Encoder as AsyncEncoder + from redis.connection import ConnectionPool, Encoder EncodedT = Union[bytes, memoryview] @@ -43,3 +44,10 @@ class CommandsProtocol(Protocol): def execute_command(self, *args, **options): ... + + +class ClusterCommandsProtocol(CommandsProtocol): + encoder: Union["AsyncEncoder", "Encoder"] + + def execute_command(self, *args, **options) -> Union[Any, Awaitable]: + ... diff --git a/tests/conftest.py b/tests/conftest.py index 903e961..e83c866 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,20 +199,21 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): def skip_if_server_version_lt(min_version: str) -> _TestDecorator: - redis_version = REDIS_INFO["version"] + redis_version = REDIS_INFO.get("version", "0") check = Version(redis_version) < Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}") def skip_if_server_version_gte(min_version: str) -> _TestDecorator: - redis_version = REDIS_INFO["version"] + redis_version = REDIS_INFO.get("version", "0") check = Version(redis_version) >= Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}") def skip_unless_arch_bits(arch_bits: int) -> _TestDecorator: return pytest.mark.skipif( - REDIS_INFO["arch_bits"] != arch_bits, reason=f"server is not {arch_bits}-bit" + REDIS_INFO.get("arch_bits", "") != arch_bits, + reason=f"server is not {arch_bits}-bit", ) @@ -235,12 +236,12 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): def skip_if_redis_enterprise() -> _TestDecorator: - check = REDIS_INFO["enterprise"] is True + check = REDIS_INFO.get("enterprise", False) is True return pytest.mark.skipif(check, reason="Redis enterprise") def skip_ifnot_redis_enterprise() -> _TestDecorator: - check = REDIS_INFO["enterprise"] is False + check = REDIS_INFO.get("enterprise", False) is False return pytest.mark.skipif(check, reason="Not running in redis enterprise") diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 02f2fc5..b8b5583 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -36,12 +36,18 @@ async def _get_info(redis_url): @pytest_asyncio.fixture( params=[ - (True, PythonParser), + pytest.param( + (True, PythonParser), + marks=pytest.mark.skipif( + REDIS_INFO["cluster_enabled"], reason="cluster mode enabled" + ), + ), (False, PythonParser), pytest.param( (True, HiredisParser), marks=pytest.mark.skipif( - not HIREDIS_AVAILABLE, reason="hiredis is not installed" + not HIREDIS_AVAILABLE or REDIS_INFO["cluster_enabled"], + reason="hiredis is not installed or cluster mode enabled", ), ), pytest.param( @@ -62,29 +68,51 @@ def create_redis(request, event_loop: asyncio.BaseEventLoop): """Wrapper around redis.create_redis.""" single_connection, parser_cls = request.param - async def f(url: str = request.config.getoption("--redis-url"), **kwargs): - single = kwargs.pop("single_connection_client", False) or single_connection - parser_class = kwargs.pop("parser_class", None) or parser_cls - url_options = parse_url(url) - url_options.update(kwargs) - pool = redis.ConnectionPool(parser_class=parser_class, **url_options) - client: redis.Redis = redis.Redis(connection_pool=pool) + async def f( + url: str = request.config.getoption("--redis-url"), + cls=redis.Redis, + flushdb=True, + **kwargs, + ): + cluster_mode = REDIS_INFO["cluster_enabled"] + if not cluster_mode: + single = kwargs.pop("single_connection_client", False) or single_connection + parser_class = kwargs.pop("parser_class", None) or parser_cls + url_options = parse_url(url) + url_options.update(kwargs) + pool = redis.ConnectionPool(parser_class=parser_class, **url_options) + client = cls(connection_pool=pool) + else: + client = redis.RedisCluster.from_url(url, **kwargs) + await client.initialize() + single = False if single: client = client.client() await client.initialize() def teardown(): async def ateardown(): - if "username" in kwargs: - return - try: - await client.flushdb() - except redis.ConnectionError: - # handle cases where a test disconnected a client - # just manually retry the flushdb - await client.flushdb() - await client.close() - await client.connection_pool.disconnect() + if not cluster_mode: + if "username" in kwargs: + return + if flushdb: + try: + await client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb() + await client.close() + await client.connection_pool.disconnect() + else: + if flushdb: + try: + await client.flushdb(target_nodes="primaries") + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb(target_nodes="primaries") + await client.close() if event_loop.is_running(): event_loop.create_task(ateardown()) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py new file mode 100644 index 0000000..6543e28 --- /dev/null +++ b/tests/test_asyncio/test_cluster.py @@ -0,0 +1,2232 @@ +import asyncio +import binascii +import datetime +import sys +import warnings + +import pytest + +from .compat import mock + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +from typing import Callable, Dict, List, Optional, Type, Union + +from _pytest.fixtures import FixtureRequest, SubRequest + +from redis.asyncio import Connection, RedisCluster +from redis.asyncio.cluster import ( + PRIMARY, + REDIS_CLUSTER_HASH_SLOTS, + REPLICA, + ClusterNode, + NodesManager, + get_node_name, +) +from redis.asyncio.parser import CommandsParser +from redis.crc import key_slot +from redis.exceptions import ( + AskError, + ClusterDownError, + ConnectionError, + DataError, + MovedError, + NoPermissionError, + RedisClusterException, + RedisError, + ResponseError, +) +from redis.utils import str_if_bytes +from tests.conftest import ( + skip_if_redis_enterprise, + skip_if_server_version_lt, + skip_unless_arch_bits, +) + +pytestmark = pytest.mark.asyncio + +default_host = "127.0.0.1" +default_port = 7000 +default_cluster_slots = [ + [0, 8191, ["127.0.0.1", 7000, "node_0"], ["127.0.0.1", 7003, "node_3"]], + [8192, 16383, ["127.0.0.1", 7001, "node_1"], ["127.0.0.1", 7002, "node_2"]], +] + + +@pytest_asyncio.fixture() +async def slowlog(request: SubRequest, r: RedisCluster) -> None: + """ + Set the slowlog threshold to 0, and the + max length to 128. This will force every + command into the slowlog and allow us + to test it + """ + # Save old values + current_config = await r.config_get(target_nodes=r.get_primaries()[0]) + old_slower_than_value = current_config["slowlog-log-slower-than"] + old_max_length_value = current_config["slowlog-max-len"] + + # Set the new values + await r.config_set("slowlog-log-slower-than", 0) + await r.config_set("slowlog-max-len", 128) + + yield + + await r.config_set("slowlog-log-slower-than", old_slower_than_value) + await r.config_set("slowlog-max-len", old_max_length_value) + + +async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster: + """ + Return a stable RedisCluster object that have deterministic + nodes and slots setup to remove the problem of different IP addresses + on different installations and machines. + """ + cluster_slots = kwargs.pop("cluster_slots", default_cluster_slots) + coverage_res = kwargs.pop("coverage_result", "yes") + cluster_enabled = kwargs.pop("cluster_enabled", True) + with mock.patch.object(ClusterNode, "execute_command") as execute_command_mock: + + async def execute_command(*_args, **_kwargs): + if _args[0] == "CLUSTER SLOTS": + mock_cluster_slots = cluster_slots + return mock_cluster_slots + elif _args[0] == "COMMAND": + return {"get": [], "set": []} + elif _args[0] == "INFO": + return {"cluster_enabled": cluster_enabled} + elif len(_args) > 1 and _args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": coverage_res} + else: + return await execute_command_mock(*_args, **_kwargs) + + execute_command_mock.side_effect = execute_command + + with mock.patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = { + "GET": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + + return await RedisCluster(*args, **kwargs) + + +def mock_node_resp( + node: ClusterNode, + response: Union[ + List[List[Union[int, List[Union[str, int]]]]], List[bytes], str, int + ], +) -> ClusterNode: + connection = mock.AsyncMock() + connection.is_connected = True + connection.read_response_without_lock.return_value = response + while node._free: + node._free.pop() + node._free.append(connection) + return node + + +def mock_all_nodes_resp( + rc: RedisCluster, + response: Union[ + List[List[Union[int, List[Union[str, int]]]]], List[bytes], int, str + ], +) -> RedisCluster: + for node in rc.get_nodes(): + mock_node_resp(node, response) + return rc + + +async def moved_redirection_helper( + request: FixtureRequest, create_redis: Callable, failover: bool = False +) -> None: + """ + Test that the client handles MOVED response after a failover. + Redirection after a failover means that the redirection address is of a + replica that was promoted to a primary. + + At first call it should return a MOVED ResponseError that will point + the client to the next server it should talk to. + + Verify that: + 1. it tries to talk to the redirected node + 2. it updates the slot's primary to the redirected node + + For a failover, also verify: + 3. the redirected node's server type updated to 'primary' + 4. the server type of the previous slot owner updated to 'replica' + """ + rc = await create_redis(cls=RedisCluster, flushdb=False) + slot = 12182 + redirect_node = None + # Get the current primary that holds this slot + prev_primary = rc.nodes_manager.get_node_from_slot(slot) + if failover: + if len(rc.nodes_manager.slots_cache[slot]) < 2: + warnings.warn("Skipping this test since it requires to have a " "replica") + return + redirect_node = rc.nodes_manager.slots_cache[slot][1] + else: + # Use one of the primaries to be the redirected node + redirect_node = rc.get_primaries()[0] + r_host = redirect_node.host + r_port = redirect_node.port + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: + + def moved_redirect_effect(self, *args, **options): + def ok_response(self, *args, **options): + assert self.host == r_host + assert self.port == r_port + + return "MOCK_OK" + + execute_command.side_effect = ok_response + raise MovedError(f"{slot} {r_host}:{r_port}") + + execute_command.side_effect = moved_redirect_effect + assert await rc.execute_command("SET", "foo", "bar") == "MOCK_OK" + slot_primary = rc.nodes_manager.slots_cache[slot][0] + assert slot_primary == redirect_node + if failover: + assert rc.get_node(host=r_host, port=r_port).server_type == PRIMARY + assert prev_primary.server_type == REPLICA + + +@pytest.mark.onlycluster +class TestRedisClusterObj: + """ + Tests for the RedisCluster class + """ + + async def test_host_port_startup_node(self) -> None: + """ + Test that it is possible to use host & port arguments as startup node + args + """ + cluster = await get_mocked_redis_client(host=default_host, port=default_port) + assert cluster.get_node(host=default_host, port=default_port) is not None + + await cluster.close() + + async def test_startup_nodes(self) -> None: + """ + Test that it is possible to use startup_nodes + argument to init the cluster + """ + port_1 = 7000 + port_2 = 7001 + startup_nodes = [ + ClusterNode(default_host, port_1), + ClusterNode(default_host, port_2), + ] + cluster = await get_mocked_redis_client(startup_nodes=startup_nodes) + assert ( + cluster.get_node(host=default_host, port=port_1) is not None + and cluster.get_node(host=default_host, port=port_2) is not None + ) + + await cluster.close() + + async def test_empty_startup_nodes(self) -> None: + """ + Test that exception is raised when empty providing empty startup_nodes + """ + with pytest.raises(RedisClusterException) as ex: + RedisCluster(startup_nodes=[]) + + assert str(ex.value).startswith( + "RedisCluster requires at least one node to discover the " "cluster" + ), str_if_bytes(ex.value) + + async def test_from_url(self, r: RedisCluster) -> None: + redis_url = f"redis://{default_host}:{default_port}/0" + with mock.patch.object(RedisCluster, "from_url") as from_url: + + async def from_url_mocked(_url, **_kwargs): + return await get_mocked_redis_client(url=_url, **_kwargs) + + from_url.side_effect = from_url_mocked + cluster = await RedisCluster.from_url(redis_url) + assert cluster.get_node(host=default_host, port=default_port) is not None + + await cluster.close() + + async def test_execute_command_errors(self, r: RedisCluster) -> None: + """ + Test that if no key is provided then exception should be raised. + """ + with pytest.raises(RedisClusterException) as ex: + await r.execute_command("GET") + assert str(ex.value).startswith( + "No way to dispatch this command to " "Redis Cluster. Missing key." + ) + + async def test_execute_command_node_flag_primaries(self, r: RedisCluster) -> None: + """ + Test command execution with nodes flag PRIMARIES + """ + primaries = r.get_primaries() + replicas = r.get_replicas() + mock_all_nodes_resp(r, "PONG") + assert await r.ping(target_nodes=RedisCluster.PRIMARIES) is True + for primary in primaries: + conn = primary._free.pop() + assert conn.read_response_without_lock.called is True + for replica in replicas: + conn = replica._free.pop() + assert conn.read_response_without_lock.called is not True + + async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None: + """ + Test command execution with nodes flag REPLICAS + """ + replicas = r.get_replicas() + if not replicas: + r = await get_mocked_redis_client(default_host, default_port) + primaries = r.get_primaries() + mock_all_nodes_resp(r, "PONG") + assert await r.ping(target_nodes=RedisCluster.REPLICAS) is True + for replica in replicas: + conn = replica._free.pop() + assert conn.read_response_without_lock.called is True + for primary in primaries: + conn = primary._free.pop() + assert conn.read_response_without_lock.called is not True + + await r.close() + + async def test_execute_command_node_flag_all_nodes(self, r: RedisCluster) -> None: + """ + Test command execution with nodes flag ALL_NODES + """ + mock_all_nodes_resp(r, "PONG") + assert await r.ping(target_nodes=RedisCluster.ALL_NODES) is True + for node in r.get_nodes(): + conn = node._free.pop() + assert conn.read_response_without_lock.called is True + + async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: + """ + Test command execution with nodes flag RANDOM + """ + mock_all_nodes_resp(r, "PONG") + assert await r.ping(target_nodes=RedisCluster.RANDOM) is True + called_count = 0 + for node in r.get_nodes(): + conn = node._free.pop() + if conn.read_response_without_lock.called is True: + called_count += 1 + assert called_count == 1 + + async def test_execute_command_default_node(self, r: RedisCluster) -> None: + """ + Test command execution without node flag is being executed on the + default node + """ + def_node = r.get_default_node() + mock_node_resp(def_node, "PONG") + assert await r.ping() is True + conn = def_node._free.pop() + assert conn.read_response_without_lock.called + + async def test_ask_redirection(self, r: RedisCluster) -> None: + """ + Test that the server handles ASK response. + + At first call it should return a ASK ResponseError that will point + the client to the next server it should talk to. + + Important thing to verify is that it tries to talk to the second node. + """ + redirect_node = r.get_nodes()[0] + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: + + def ask_redirect_effect(self, *args, **options): + def ok_response(self, *args, **options): + assert self.host == redirect_node.host + assert self.port == redirect_node.port + + return "MOCK_OK" + + execute_command.side_effect = ok_response + raise AskError(f"12182 {redirect_node.host}:{redirect_node.port}") + + execute_command.side_effect = ask_redirect_effect + + assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK" + + async def test_moved_redirection( + self, request: FixtureRequest, create_redis: Callable + ) -> None: + """ + Test that the client handles MOVED response. + """ + await moved_redirection_helper(request, create_redis, failover=False) + + async def test_moved_redirection_after_failover( + self, request: FixtureRequest, create_redis: Callable + ) -> None: + """ + Test that the client handles MOVED response after a failover. + """ + await moved_redirection_helper(request, create_redis, failover=True) + + async def test_refresh_using_specific_nodes( + self, request: FixtureRequest, create_redis: Callable + ) -> None: + """ + Test making calls on specific nodes when the cluster has failed over to + another node + """ + node_7006 = ClusterNode(host=default_host, port=7006, server_type=PRIMARY) + node_7007 = ClusterNode(host=default_host, port=7007, server_type=PRIMARY) + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: + with mock.patch.object( + NodesManager, "initialize", autospec=True + ) as initialize: + with mock.patch.multiple( + Connection, + send_packed_command=mock.DEFAULT, + connect=mock.DEFAULT, + can_read=mock.DEFAULT, + ) as mocks: + # simulate 7006 as a failed node + def execute_command_mock(self, *args, **options): + if self.port == 7006: + execute_command.failed_calls += 1 + raise ClusterDownError( + "CLUSTERDOWN The cluster is " + "down. Use CLUSTER INFO for " + "more information" + ) + elif self.port == 7007: + execute_command.successful_calls += 1 + + def initialize_mock(self): + # start with all slots mapped to 7006 + self.nodes_cache = {node_7006.name: node_7006} + self.default_node = node_7006 + self.slots_cache = {} + + for i in range(0, 16383): + self.slots_cache[i] = [node_7006] + + # After the first connection fails, a reinitialize + # should follow the cluster to 7007 + def map_7007(self): + self.nodes_cache = {node_7007.name: node_7007} + self.default_node = node_7007 + self.slots_cache = {} + + for i in range(0, 16383): + self.slots_cache[i] = [node_7007] + + # Change initialize side effect for the second call + initialize.side_effect = map_7007 + + execute_command.side_effect = execute_command_mock + execute_command.successful_calls = 0 + execute_command.failed_calls = 0 + initialize.side_effect = initialize_mock + mocks["can_read"].return_value = False + mocks["send_packed_command"].return_value = "MOCK_OK" + mocks["connect"].return_value = None + with mock.patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = { + "GET": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + + rc = await create_redis(cls=RedisCluster, flushdb=False) + assert len(rc.get_nodes()) == 1 + assert rc.get_node(node_name=node_7006.name) is not None + + await rc.get("foo") + + # Cluster should now point to 7007, and there should be + # one failed and one successful call + assert len(rc.get_nodes()) == 1 + assert rc.get_node(node_name=node_7007.name) is not None + assert rc.get_node(node_name=node_7006.name) is None + assert execute_command.failed_calls == 1 + assert execute_command.successful_calls == 1 + + async def test_reading_from_replicas_in_round_robin(self) -> None: + with mock.patch.multiple( + Connection, + send_command=mock.DEFAULT, + read_response_without_lock=mock.DEFAULT, + _connect=mock.DEFAULT, + can_read=mock.DEFAULT, + on_connect=mock.DEFAULT, + ) as mocks: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: + + async def execute_command_mock_first(self, *args, **options): + await self.connection_class(**self.connection_kwargs).connect() + # Primary + assert self.port == 7001 + execute_command.side_effect = execute_command_mock_second + return "MOCK_OK" + + def execute_command_mock_second(self, *args, **options): + # Replica + assert self.port == 7002 + execute_command.side_effect = execute_command_mock_third + return "MOCK_OK" + + def execute_command_mock_third(self, *args, **options): + # Primary + assert self.port == 7001 + return "MOCK_OK" + + # We don't need to create a real cluster connection but we + # do want RedisCluster.on_connect function to get called, + # so we'll mock some of the Connection's functions to allow it + execute_command.side_effect = execute_command_mock_first + mocks["send_command"].return_value = True + mocks["read_response_without_lock"].return_value = "OK" + mocks["_connect"].return_value = True + mocks["can_read"].return_value = False + mocks["on_connect"].return_value = True + + # Create a cluster with reading from replications + read_cluster = await get_mocked_redis_client( + host=default_host, port=default_port, read_from_replicas=True + ) + assert read_cluster.read_from_replicas is True + # Check that we read from the slot's nodes in a round robin + # matter. + # 'foo' belongs to slot 12182 and the slot's nodes are: + # [(127.0.0.1,7001,primary), (127.0.0.1,7002,replica)] + await read_cluster.get("foo") + await read_cluster.get("foo") + await read_cluster.get("foo") + mocks["send_command"].assert_has_calls([mock.call("READONLY")]) + + await read_cluster.close() + + async def test_keyslot(self, r: RedisCluster) -> None: + """ + Test that method will compute correct key in all supported cases + """ + assert r.keyslot("foo") == 12182 + assert r.keyslot("{foo}bar") == 12182 + assert r.keyslot("{foo}") == 12182 + assert r.keyslot(1337) == 4314 + + assert r.keyslot(125) == r.keyslot(b"125") + assert r.keyslot(125) == r.keyslot("\x31\x32\x35") + assert r.keyslot("大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96") + assert r.keyslot("大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96") + assert r.keyslot(1337.1234) == r.keyslot("1337.1234") + assert r.keyslot(1337) == r.keyslot("1337") + assert r.keyslot(b"abc") == r.keyslot("abc") + + async def test_get_node_name(self) -> None: + assert ( + get_node_name(default_host, default_port) + == f"{default_host}:{default_port}" + ) + + async def test_all_nodes(self, r: RedisCluster) -> None: + """ + Set a list of nodes and it should be possible to iterate over all + """ + nodes = [node for node in r.nodes_manager.nodes_cache.values()] + + for i, node in enumerate(r.get_nodes()): + assert node in nodes + + async def test_all_nodes_masters(self, r: RedisCluster) -> None: + """ + Set a list of nodes with random primaries/replicas config and it shold + be possible to iterate over all of them. + """ + nodes = [ + node + for node in r.nodes_manager.nodes_cache.values() + if node.server_type == PRIMARY + ] + + for node in r.get_primaries(): + assert node in nodes + + @pytest.mark.parametrize("error", RedisCluster.ERRORS_ALLOW_RETRY) + async def test_cluster_down_overreaches_retry_attempts( + self, + error: Union[Type[TimeoutError], Type[ClusterDownError], Type[ConnectionError]], + ) -> None: + """ + When error that allows retry is thrown, test that we retry executing + the command as many times as configured in cluster_error_retry_attempts + and then raise the exception + """ + with mock.patch.object(RedisCluster, "_execute_command") as execute_command: + + def raise_error(target_node, *args, **kwargs): + execute_command.failed_calls += 1 + raise error("mocked error") + + execute_command.side_effect = raise_error + + rc = await get_mocked_redis_client(host=default_host, port=default_port) + + with pytest.raises(error): + await rc.get("bar") + assert execute_command.failed_calls == rc.cluster_error_retry_attempts + + await rc.close() + + async def test_set_default_node_success(self, r: RedisCluster) -> None: + """ + test successful replacement of the default cluster node + """ + default_node = r.get_default_node() + # get a different node + new_def_node = None + for node in r.get_nodes(): + if node != default_node: + new_def_node = node + break + r.set_default_node(new_def_node) + assert r.get_default_node() == new_def_node + + async def test_set_default_node_failure(self, r: RedisCluster) -> None: + """ + test failed replacement of the default cluster node + """ + default_node = r.get_default_node() + new_def_node = ClusterNode("1.1.1.1", 1111) + with pytest.raises(DataError): + r.set_default_node(None) + with pytest.raises(DataError): + r.set_default_node(new_def_node) + assert r.get_default_node() == default_node + + async def test_get_node_from_key(self, r: RedisCluster) -> None: + """ + Test that get_node_from_key function returns the correct node + """ + key = "bar" + slot = r.keyslot(key) + slot_nodes = r.nodes_manager.slots_cache.get(slot) + primary = slot_nodes[0] + assert r.get_node_from_key(key, replica=False) == primary + replica = r.get_node_from_key(key, replica=True) + if replica is not None: + assert replica.server_type == REPLICA + assert replica in slot_nodes + + @skip_if_redis_enterprise() + async def test_not_require_full_coverage_cluster_down_error( + self, r: RedisCluster + ) -> None: + """ + When require_full_coverage is set to False (default client config) and not + all slots are covered, if one of the nodes has 'cluster-require_full_coverage' + config set to 'yes' some key-based commands should throw ClusterDownError + """ + node = r.get_node_from_key("foo") + missing_slot = r.keyslot("foo") + assert await r.set("foo", "bar") is True + try: + assert all(await r.cluster_delslots(missing_slot)) + with pytest.raises(ClusterDownError): + await r.exists("foo") + finally: + try: + # Add back the missing slot + assert await r.cluster_addslots(node, missing_slot) is True + # Make sure we are not getting ClusterDownError anymore + assert await r.exists("foo") == 1 + except ResponseError as e: + if f"Slot {missing_slot} is already busy" in str(e): + # It can happen if the test failed to delete this slot + pass + else: + raise e + + async def test_can_run_concurrent_commands(self, r: RedisCluster) -> None: + assert await r.ping(target_nodes=RedisCluster.ALL_NODES) is True + assert all( + await asyncio.gather( + *(r.ping(target_nodes=RedisCluster.ALL_NODES) for _ in range(100)) + ) + ) + + +@pytest.mark.onlycluster +class TestClusterRedisCommands: + """ + Tests for RedisCluster unique commands + """ + + async def test_get_and_set(self, r: RedisCluster) -> None: + # get and set can't be tested independently of each other + assert await r.get("a") is None + byte_string = b"value" + integer = 5 + unicode_string = chr(3456) + "abcd" + chr(3421) + assert await r.set("byte_string", byte_string) + assert await r.set("integer", 5) + assert await r.set("unicode_string", unicode_string) + assert await r.get("byte_string") == byte_string + assert await r.get("integer") == str(integer).encode() + assert (await r.get("unicode_string")).decode("utf-8") == unicode_string + + async def test_mget_nonatomic(self, r: RedisCluster) -> None: + assert await r.mget_nonatomic([]) == [] + assert await r.mget_nonatomic(["a", "b"]) == [None, None] + await r.set("a", "1") + await r.set("b", "2") + await r.set("c", "3") + + assert await r.mget_nonatomic("a", "other", "b", "c") == [ + b"1", + None, + b"2", + b"3", + ] + + async def test_mset_nonatomic(self, r: RedisCluster) -> None: + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + assert await r.mset_nonatomic(d) + for k, v in d.items(): + assert await r.get(k) == v + + async def test_config_set(self, r: RedisCluster) -> None: + assert await r.config_set("slowlog-log-slower-than", 0) + + async def test_cluster_config_resetstat(self, r: RedisCluster) -> None: + await r.ping(target_nodes="all") + all_info = await r.info(target_nodes="all") + prior_commands_processed = -1 + for node_info in all_info.values(): + prior_commands_processed = node_info["total_commands_processed"] + assert prior_commands_processed >= 1 + await r.config_resetstat(target_nodes="all") + all_info = await r.info(target_nodes="all") + for node_info in all_info.values(): + reset_commands_processed = node_info["total_commands_processed"] + assert reset_commands_processed < prior_commands_processed + + async def test_client_setname(self, r: RedisCluster) -> None: + node = r.get_random_node() + await r.client_setname("redis_py_test", target_nodes=node) + client_name = await r.client_getname(target_nodes=node) + assert client_name == "redis_py_test" + + async def test_exists(self, r: RedisCluster) -> None: + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + await r.mset_nonatomic(d) + assert await r.exists(*d.keys()) == len(d) + + async def test_delete(self, r: RedisCluster) -> None: + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + await r.mset_nonatomic(d) + assert await r.delete(*d.keys()) == len(d) + assert await r.delete(*d.keys()) == 0 + + async def test_touch(self, r: RedisCluster) -> None: + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + await r.mset_nonatomic(d) + assert await r.touch(*d.keys()) == len(d) + + async def test_unlink(self, r: RedisCluster) -> None: + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + await r.mset_nonatomic(d) + assert await r.unlink(*d.keys()) == len(d) + # Unlink is non-blocking so we sleep before + # verifying the deletion + await asyncio.sleep(0.1) + assert await r.unlink(*d.keys()) == 0 + + @skip_if_redis_enterprise() + async def test_cluster_myid(self, r: RedisCluster) -> None: + node = r.get_random_node() + myid = await r.cluster_myid(node) + assert len(myid) == 40 + + @skip_if_redis_enterprise() + async def test_cluster_slots(self, r: RedisCluster) -> None: + mock_all_nodes_resp(r, default_cluster_slots) + cluster_slots = await r.cluster_slots() + assert isinstance(cluster_slots, dict) + assert len(default_cluster_slots) == len(cluster_slots) + assert cluster_slots.get((0, 8191)) is not None + assert cluster_slots.get((0, 8191)).get("primary") == ("127.0.0.1", 7000) + + @skip_if_redis_enterprise() + async def test_cluster_addslots(self, r: RedisCluster) -> None: + node = r.get_random_node() + mock_node_resp(node, "OK") + assert await r.cluster_addslots(node, 1, 2, 3) is True + + @skip_if_server_version_lt("7.0.0") + @skip_if_redis_enterprise() + async def test_cluster_addslotsrange(self, r: RedisCluster): + node = r.get_random_node() + mock_node_resp(node, "OK") + assert await r.cluster_addslotsrange(node, 1, 5) + + @skip_if_redis_enterprise() + async def test_cluster_countkeysinslot(self, r: RedisCluster) -> None: + node = r.nodes_manager.get_node_from_slot(1) + mock_node_resp(node, 2) + assert await r.cluster_countkeysinslot(1) == 2 + + async def test_cluster_count_failure_report(self, r: RedisCluster) -> None: + mock_all_nodes_resp(r, 0) + assert await r.cluster_count_failure_report("node_0") == 0 + + @skip_if_redis_enterprise() + async def test_cluster_delslots(self) -> None: + cluster_slots = [ + [0, 8191, ["127.0.0.1", 7000, "node_0"]], + [8192, 16383, ["127.0.0.1", 7001, "node_1"]], + ] + r = await get_mocked_redis_client( + host=default_host, port=default_port, cluster_slots=cluster_slots + ) + mock_all_nodes_resp(r, "OK") + node0 = r.get_node(default_host, 7000) + node1 = r.get_node(default_host, 7001) + assert await r.cluster_delslots(0, 8192) == [True, True] + assert node0._free.pop().read_response_without_lock.called + assert node1._free.pop().read_response_without_lock.called + + await r.close() + + @skip_if_server_version_lt("7.0.0") + @skip_if_redis_enterprise() + async def test_cluster_delslotsrange(self, r: RedisCluster): + node = r.get_random_node() + mock_node_resp(node, "OK") + await r.cluster_addslots(node, 1, 2, 3, 4, 5) + assert await r.cluster_delslotsrange(1, 5) + + @skip_if_redis_enterprise() + async def test_cluster_failover(self, r: RedisCluster) -> None: + node = r.get_random_node() + mock_node_resp(node, "OK") + assert await r.cluster_failover(node) is True + assert await r.cluster_failover(node, "FORCE") is True + assert await r.cluster_failover(node, "TAKEOVER") is True + with pytest.raises(RedisError): + await r.cluster_failover(node, "FORCT") + + @skip_if_redis_enterprise() + async def test_cluster_info(self, r: RedisCluster) -> None: + info = await r.cluster_info() + assert isinstance(info, dict) + assert info["cluster_state"] == "ok" + + @skip_if_redis_enterprise() + async def test_cluster_keyslot(self, r: RedisCluster) -> None: + mock_all_nodes_resp(r, 12182) + assert await r.cluster_keyslot("foo") == 12182 + + @skip_if_redis_enterprise() + async def test_cluster_meet(self, r: RedisCluster) -> None: + node = r.get_default_node() + mock_node_resp(node, "OK") + assert await r.cluster_meet("127.0.0.1", 6379) is True + + @skip_if_redis_enterprise() + async def test_cluster_nodes(self, r: RedisCluster) -> None: + response = ( + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " + "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " + "1447836263059 5 connected\n" + "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " + "master - 0 1447836264065 0 connected\n" + "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " + "myself,master - 0 0 2 connected 5461-10922\n" + "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836262556 3 connected\n" + "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " + "master - 0 1447836262555 7 connected 0-5460\n" + "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " + "master - 0 1447836263562 3 connected 10923-16383\n" + "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " + "master,fail - 1447829446956 1447829444948 1 disconnected\n" + ) + mock_all_nodes_resp(r, response) + nodes = await r.cluster_nodes() + assert len(nodes) == 7 + assert nodes.get("172.17.0.7:7006") is not None + assert ( + nodes.get("172.17.0.7:7006").get("node_id") + == "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" + ) + + @skip_if_redis_enterprise() + async def test_cluster_nodes_importing_migrating(self, r: RedisCluster) -> None: + response = ( + "488ead2fcce24d8c0f158f9172cb1f4a9e040fe5 127.0.0.1:16381@26381 " + "master - 0 1648975557664 3 connected 10923-16383\n" + "8ae2e70812db80776f739a72374e57fc4ae6f89d 127.0.0.1:16380@26380 " + "master - 0 1648975555000 2 connected 1 5461-10922 [" + "2-<-ed8007ccfa2d91a7b76f8e6fba7ba7e257034a16]\n" + "ed8007ccfa2d91a7b76f8e6fba7ba7e257034a16 127.0.0.1:16379@26379 " + "myself,master - 0 1648975556000 1 connected 0 2-5460 [" + "2->-8ae2e70812db80776f739a72374e57fc4ae6f89d]\n" + ) + mock_all_nodes_resp(r, response) + nodes = await r.cluster_nodes() + assert len(nodes) == 3 + node_16379 = nodes.get("127.0.0.1:16379") + node_16380 = nodes.get("127.0.0.1:16380") + node_16381 = nodes.get("127.0.0.1:16381") + assert node_16379.get("migrations") == [ + { + "slot": "2", + "node_id": "8ae2e70812db80776f739a72374e57fc4ae6f89d", + "state": "migrating", + } + ] + assert node_16379.get("slots") == [["0"], ["2", "5460"]] + assert node_16380.get("migrations") == [ + { + "slot": "2", + "node_id": "ed8007ccfa2d91a7b76f8e6fba7ba7e257034a16", + "state": "importing", + } + ] + assert node_16380.get("slots") == [["1"], ["5461", "10922"]] + assert node_16381.get("slots") == [["10923", "16383"]] + assert node_16381.get("migrations") == [] + + @skip_if_redis_enterprise() + async def test_cluster_replicate(self, r: RedisCluster) -> None: + node = r.get_random_node() + all_replicas = r.get_replicas() + mock_all_nodes_resp(r, "OK") + assert await r.cluster_replicate(node, "c8253bae761cb61857d") is True + results = await r.cluster_replicate(all_replicas, "c8253bae761cb61857d") + if isinstance(results, dict): + for res in results.values(): + assert res is True + else: + assert results is True + + @skip_if_redis_enterprise() + async def test_cluster_reset(self, r: RedisCluster) -> None: + mock_all_nodes_resp(r, "OK") + assert await r.cluster_reset() is True + assert await r.cluster_reset(False) is True + all_results = await r.cluster_reset(False, target_nodes="all") + for res in all_results.values(): + assert res is True + + @skip_if_redis_enterprise() + async def test_cluster_save_config(self, r: RedisCluster) -> None: + node = r.get_random_node() + all_nodes = r.get_nodes() + mock_all_nodes_resp(r, "OK") + assert await r.cluster_save_config(node) is True + all_results = await r.cluster_save_config(all_nodes) + for res in all_results.values(): + assert res is True + + @skip_if_redis_enterprise() + async def test_cluster_get_keys_in_slot(self, r: RedisCluster) -> None: + response = ["{foo}1", "{foo}2"] + node = r.nodes_manager.get_node_from_slot(12182) + mock_node_resp(node, response) + keys = await r.cluster_get_keys_in_slot(12182, 4) + assert keys == response + + @skip_if_redis_enterprise() + async def test_cluster_set_config_epoch(self, r: RedisCluster) -> None: + mock_all_nodes_resp(r, "OK") + assert await r.cluster_set_config_epoch(3) is True + all_results = await r.cluster_set_config_epoch(3, target_nodes="all") + for res in all_results.values(): + assert res is True + + @skip_if_redis_enterprise() + async def test_cluster_setslot(self, r: RedisCluster) -> None: + node = r.get_random_node() + mock_node_resp(node, "OK") + assert await r.cluster_setslot(node, "node_0", 1218, "IMPORTING") is True + assert await r.cluster_setslot(node, "node_0", 1218, "NODE") is True + assert await r.cluster_setslot(node, "node_0", 1218, "MIGRATING") is True + with pytest.raises(RedisError): + await r.cluster_failover(node, "STABLE") + with pytest.raises(RedisError): + await r.cluster_failover(node, "STATE") + + async def test_cluster_setslot_stable(self, r: RedisCluster) -> None: + node = r.nodes_manager.get_node_from_slot(12182) + mock_node_resp(node, "OK") + assert await r.cluster_setslot_stable(12182) is True + assert node._free.pop().read_response_without_lock.called + + @skip_if_redis_enterprise() + async def test_cluster_replicas(self, r: RedisCluster) -> None: + response = [ + b"01eca22229cf3c652b6fca0d09ff6941e0d2e3 " + b"127.0.0.1:6377@16377 slave " + b"52611e796814b78e90ad94be9d769a4f668f9a 0 " + b"1634550063436 4 connected", + b"r4xfga22229cf3c652b6fca0d09ff69f3e0d4d " + b"127.0.0.1:6378@16378 slave " + b"52611e796814b78e90ad94be9d769a4f668f9a 0 " + b"1634550063436 4 connected", + ] + mock_all_nodes_resp(r, response) + replicas = await r.cluster_replicas("52611e796814b78e90ad94be9d769a4f668f9a") + assert replicas.get("127.0.0.1:6377") is not None + assert replicas.get("127.0.0.1:6378") is not None + assert ( + replicas.get("127.0.0.1:6378").get("node_id") + == "r4xfga22229cf3c652b6fca0d09ff69f3e0d4d" + ) + + @skip_if_server_version_lt("7.0.0") + async def test_cluster_links(self, r: RedisCluster): + node = r.get_random_node() + res = await r.cluster_links(node) + links_to = sum(x.count("to") for x in res) + links_for = sum(x.count("from") for x in res) + assert links_to == links_for + for i in range(0, len(res) - 1, 2): + assert res[i][3] == res[i + 1][3] + + @skip_if_redis_enterprise() + async def test_readonly(self) -> None: + r = await get_mocked_redis_client(host=default_host, port=default_port) + mock_all_nodes_resp(r, "OK") + assert await r.readonly() is True + all_replicas_results = await r.readonly(target_nodes="replicas") + for res in all_replicas_results.values(): + assert res is True + for replica in r.get_replicas(): + assert replica._free.pop().read_response_without_lock.called + + await r.close() + + @skip_if_redis_enterprise() + async def test_readwrite(self) -> None: + r = await get_mocked_redis_client(host=default_host, port=default_port) + mock_all_nodes_resp(r, "OK") + assert await r.readwrite() is True + all_replicas_results = await r.readwrite(target_nodes="replicas") + for res in all_replicas_results.values(): + assert res is True + for replica in r.get_replicas(): + assert replica._free.pop().read_response_without_lock.called + + await r.close() + + @skip_if_redis_enterprise() + async def test_bgsave(self, r: RedisCluster) -> None: + assert await r.bgsave() + await asyncio.sleep(0.3) + assert await r.bgsave(True) + + async def test_info(self, r: RedisCluster) -> None: + # Map keys to same slot + await r.set("x{1}", 1) + await r.set("y{1}", 2) + await r.set("z{1}", 3) + # Get node that handles the slot + slot = r.keyslot("x{1}") + node = r.nodes_manager.get_node_from_slot(slot) + # Run info on that node + info = await r.info(target_nodes=node) + assert isinstance(info, dict) + assert info["db0"]["keys"] == 3 + + async def _init_slowlog_test(self, r: RedisCluster, node: ClusterNode) -> str: + slowlog_lim = await r.config_get("slowlog-log-slower-than", target_nodes=node) + assert ( + await r.config_set("slowlog-log-slower-than", 0, target_nodes=node) is True + ) + return slowlog_lim["slowlog-log-slower-than"] + + async def _teardown_slowlog_test( + self, r: RedisCluster, node: ClusterNode, prev_limit: str + ) -> None: + assert ( + await r.config_set("slowlog-log-slower-than", prev_limit, target_nodes=node) + is True + ) + + async def test_slowlog_get( + self, r: RedisCluster, slowlog: Optional[List[Dict[str, Union[int, bytes]]]] + ) -> None: + unicode_string = chr(3456) + "abcd" + chr(3421) + node = r.get_node_from_key(unicode_string) + slowlog_limit = await self._init_slowlog_test(r, node) + assert await r.slowlog_reset(target_nodes=node) + await r.get(unicode_string) + slowlog = await r.slowlog_get(target_nodes=node) + assert isinstance(slowlog, list) + commands = [log["command"] for log in slowlog] + + get_command = b" ".join((b"GET", unicode_string.encode("utf-8"))) + assert get_command in commands + assert b"SLOWLOG RESET" in commands + + # the order should be ['GET <uni string>', 'SLOWLOG RESET'], + # but if other clients are executing commands at the same time, there + # could be commands, before, between, or after, so just check that + # the two we care about are in the appropriate order. + assert commands.index(get_command) < commands.index(b"SLOWLOG RESET") + + # make sure other attributes are typed correctly + assert isinstance(slowlog[0]["start_time"], int) + assert isinstance(slowlog[0]["duration"], int) + # rollback the slowlog limit to its original value + await self._teardown_slowlog_test(r, node, slowlog_limit) + + async def test_slowlog_get_limit( + self, r: RedisCluster, slowlog: Optional[List[Dict[str, Union[int, bytes]]]] + ) -> None: + assert await r.slowlog_reset() + node = r.get_node_from_key("foo") + slowlog_limit = await self._init_slowlog_test(r, node) + await r.get("foo") + slowlog = await r.slowlog_get(1, target_nodes=node) + assert isinstance(slowlog, list) + # only one command, based on the number we passed to slowlog_get() + assert len(slowlog) == 1 + await self._teardown_slowlog_test(r, node, slowlog_limit) + + async def test_slowlog_length(self, r: RedisCluster, slowlog: None) -> None: + await r.get("foo") + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + slowlog_len = await r.slowlog_len(target_nodes=node) + assert isinstance(slowlog_len, int) + + async def test_time(self, r: RedisCluster) -> None: + t = await r.time(target_nodes=r.get_primaries()[0]) + assert len(t) == 2 + assert isinstance(t[0], int) + assert isinstance(t[1], int) + + @skip_if_server_version_lt("4.0.0") + async def test_memory_usage(self, r: RedisCluster) -> None: + await r.set("foo", "bar") + assert isinstance(await r.memory_usage("foo"), int) + + @skip_if_server_version_lt("4.0.0") + @skip_if_redis_enterprise() + async def test_memory_malloc_stats(self, r: RedisCluster) -> None: + assert await r.memory_malloc_stats() + + @skip_if_server_version_lt("4.0.0") + @skip_if_redis_enterprise() + async def test_memory_stats(self, r: RedisCluster) -> None: + # put a key into the current db to make sure that "db.<current-db>" + # has data + await r.set("foo", "bar") + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + stats = await r.memory_stats(target_nodes=node) + assert isinstance(stats, dict) + for key, value in stats.items(): + if key.startswith("db."): + assert isinstance(value, dict) + + @skip_if_server_version_lt("4.0.0") + async def test_memory_help(self, r: RedisCluster) -> None: + with pytest.raises(NotImplementedError): + await r.memory_help() + + @skip_if_server_version_lt("4.0.0") + async def test_memory_doctor(self, r: RedisCluster) -> None: + with pytest.raises(NotImplementedError): + await r.memory_doctor() + + @skip_if_redis_enterprise() + async def test_lastsave(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + assert isinstance(await r.lastsave(target_nodes=node), datetime.datetime) + + async def test_cluster_echo(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + assert await r.echo("foo bar", target_nodes=node) == b"foo bar" + + @skip_if_server_version_lt("1.0.0") + async def test_debug_segfault(self, r: RedisCluster) -> None: + with pytest.raises(NotImplementedError): + await r.debug_segfault() + + async def test_config_resetstat(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + await r.ping(target_nodes=node) + prior_commands_processed = int( + (await r.info(target_nodes=node))["total_commands_processed"] + ) + assert prior_commands_processed >= 1 + await r.config_resetstat(target_nodes=node) + reset_commands_processed = int( + (await r.info(target_nodes=node))["total_commands_processed"] + ) + assert reset_commands_processed < prior_commands_processed + + @skip_if_server_version_lt("6.2.0") + async def test_client_trackinginfo(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + res = await r.client_trackinginfo(target_nodes=node) + assert len(res) > 2 + assert "prefixes" in res + + @skip_if_server_version_lt("2.9.50") + async def test_client_pause(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + assert await r.client_pause(1, target_nodes=node) + assert await r.client_pause(timeout=1, target_nodes=node) + with pytest.raises(RedisError): + await r.client_pause(timeout="not an integer", target_nodes=node) + + @skip_if_server_version_lt("6.2.0") + @skip_if_redis_enterprise() + async def test_client_unpause(self, r: RedisCluster) -> None: + assert await r.client_unpause() + + @skip_if_server_version_lt("5.0.0") + async def test_client_id(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + assert await r.client_id(target_nodes=node) > 0 + + @skip_if_server_version_lt("5.0.0") + async def test_client_unblock(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + myid = await r.client_id(target_nodes=node) + assert not await r.client_unblock(myid, target_nodes=node) + assert not await r.client_unblock(myid, error=True, target_nodes=node) + assert not await r.client_unblock(myid, error=False, target_nodes=node) + + @skip_if_server_version_lt("6.0.0") + async def test_client_getredir(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + assert isinstance(await r.client_getredir(target_nodes=node), int) + assert await r.client_getredir(target_nodes=node) == -1 + + @skip_if_server_version_lt("6.2.0") + async def test_client_info(self, r: RedisCluster) -> None: + node = r.get_primaries()[0] + info = await r.client_info(target_nodes=node) + assert isinstance(info, dict) + assert "addr" in info + + @skip_if_server_version_lt("2.6.9") + async def test_client_kill(self, r: RedisCluster, r2: RedisCluster) -> None: + node = r.get_primaries()[0] + await r.client_setname("redis-py-c1", target_nodes="all") + await r2.client_setname("redis-py-c2", target_nodes="all") + clients = [ + client + for client in await r.client_list(target_nodes=node) + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 2 + clients_by_name = {client.get("name"): client for client in clients} + + client_addr = clients_by_name["redis-py-c2"].get("addr") + assert await r.client_kill(client_addr, target_nodes=node) is True + + clients = [ + client + for client in await r.client_list(target_nodes=node) + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 1 + assert clients[0].get("name") == "redis-py-c1" + + @skip_if_server_version_lt("2.6.0") + async def test_cluster_bitop_not_empty_string(self, r: RedisCluster) -> None: + await r.set("{foo}a", "") + await r.bitop("not", "{foo}r", "{foo}a") + assert await r.get("{foo}r") is None + + @skip_if_server_version_lt("2.6.0") + async def test_cluster_bitop_not(self, r: RedisCluster) -> None: + test_str = b"\xAA\x00\xFF\x55" + correct = ~0xAA00FF55 & 0xFFFFFFFF + await r.set("{foo}a", test_str) + await r.bitop("not", "{foo}r", "{foo}a") + assert int(binascii.hexlify(await r.get("{foo}r")), 16) == correct + + @skip_if_server_version_lt("2.6.0") + async def test_cluster_bitop_not_in_place(self, r: RedisCluster) -> None: + test_str = b"\xAA\x00\xFF\x55" + correct = ~0xAA00FF55 & 0xFFFFFFFF + await r.set("{foo}a", test_str) + await r.bitop("not", "{foo}a", "{foo}a") + assert int(binascii.hexlify(await r.get("{foo}a")), 16) == correct + + @skip_if_server_version_lt("2.6.0") + async def test_cluster_bitop_single_string(self, r: RedisCluster) -> None: + test_str = b"\x01\x02\xFF" + await r.set("{foo}a", test_str) + await r.bitop("and", "{foo}res1", "{foo}a") + await r.bitop("or", "{foo}res2", "{foo}a") + await r.bitop("xor", "{foo}res3", "{foo}a") + assert await r.get("{foo}res1") == test_str + assert await r.get("{foo}res2") == test_str + assert await r.get("{foo}res3") == test_str + + @skip_if_server_version_lt("2.6.0") + async def test_cluster_bitop_string_operands(self, r: RedisCluster) -> None: + await r.set("{foo}a", b"\x01\x02\xFF\xFF") + await r.set("{foo}b", b"\x01\x02\xFF") + await r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") + await r.bitop("or", "{foo}res2", "{foo}a", "{foo}b") + await r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b") + assert int(binascii.hexlify(await r.get("{foo}res1")), 16) == 0x0102FF00 + assert int(binascii.hexlify(await r.get("{foo}res2")), 16) == 0x0102FFFF + assert int(binascii.hexlify(await r.get("{foo}res3")), 16) == 0x000000FF + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_copy(self, r: RedisCluster) -> None: + assert await r.copy("{foo}a", "{foo}b") == 0 + await r.set("{foo}a", "bar") + assert await r.copy("{foo}a", "{foo}b") == 1 + assert await r.get("{foo}a") == b"bar" + assert await r.get("{foo}b") == b"bar" + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_copy_and_replace(self, r: RedisCluster) -> None: + await r.set("{foo}a", "foo1") + await r.set("{foo}b", "foo2") + assert await r.copy("{foo}a", "{foo}b") == 0 + assert await r.copy("{foo}a", "{foo}b", replace=True) == 1 + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_lmove(self, r: RedisCluster) -> None: + await r.rpush("{foo}a", "one", "two", "three", "four") + assert await r.lmove("{foo}a", "{foo}b") + assert await r.lmove("{foo}a", "{foo}b", "right", "left") + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_blmove(self, r: RedisCluster) -> None: + await r.rpush("{foo}a", "one", "two", "three", "four") + assert await r.blmove("{foo}a", "{foo}b", 5) + assert await r.blmove("{foo}a", "{foo}b", 1, "RIGHT", "LEFT") + + async def test_cluster_msetnx(self, r: RedisCluster) -> None: + d = {"{foo}a": b"1", "{foo}b": b"2", "{foo}c": b"3"} + assert await r.msetnx(d) + d2 = {"{foo}a": b"x", "{foo}d": b"4"} + assert not await r.msetnx(d2) + for k, v in d.items(): + assert await r.get(k) == v + assert await r.get("{foo}d") is None + + async def test_cluster_rename(self, r: RedisCluster) -> None: + await r.set("{foo}a", "1") + assert await r.rename("{foo}a", "{foo}b") + assert await r.get("{foo}a") is None + assert await r.get("{foo}b") == b"1" + + async def test_cluster_renamenx(self, r: RedisCluster) -> None: + await r.set("{foo}a", "1") + await r.set("{foo}b", "2") + assert not await r.renamenx("{foo}a", "{foo}b") + assert await r.get("{foo}a") == b"1" + assert await r.get("{foo}b") == b"2" + + # LIST COMMANDS + async def test_cluster_blpop(self, r: RedisCluster) -> None: + await r.rpush("{foo}a", "1", "2") + await r.rpush("{foo}b", "3", "4") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) is None + await r.rpush("{foo}c", "1") + assert await r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + + async def test_cluster_brpop(self, r: RedisCluster) -> None: + await r.rpush("{foo}a", "1", "2") + await r.rpush("{foo}b", "3", "4") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) is None + await r.rpush("{foo}c", "1") + assert await r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + + async def test_cluster_brpoplpush(self, r: RedisCluster) -> None: + await r.rpush("{foo}a", "1", "2") + await r.rpush("{foo}b", "3", "4") + assert await r.brpoplpush("{foo}a", "{foo}b") == b"2" + assert await r.brpoplpush("{foo}a", "{foo}b") == b"1" + assert await r.brpoplpush("{foo}a", "{foo}b", timeout=1) is None + assert await r.lrange("{foo}a", 0, -1) == [] + assert await r.lrange("{foo}b", 0, -1) == [b"1", b"2", b"3", b"4"] + + async def test_cluster_brpoplpush_empty_string(self, r: RedisCluster) -> None: + await r.rpush("{foo}a", "") + assert await r.brpoplpush("{foo}a", "{foo}b") == b"" + + async def test_cluster_rpoplpush(self, r: RedisCluster) -> None: + await r.rpush("{foo}a", "a1", "a2", "a3") + await r.rpush("{foo}b", "b1", "b2", "b3") + assert await r.rpoplpush("{foo}a", "{foo}b") == b"a3" + assert await r.lrange("{foo}a", 0, -1) == [b"a1", b"a2"] + assert await r.lrange("{foo}b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] + + async def test_cluster_sdiff(self, r: RedisCluster) -> None: + await r.sadd("{foo}a", "1", "2", "3") + assert await r.sdiff("{foo}a", "{foo}b") == {b"1", b"2", b"3"} + await r.sadd("{foo}b", "2", "3") + assert await r.sdiff("{foo}a", "{foo}b") == {b"1"} + + async def test_cluster_sdiffstore(self, r: RedisCluster) -> None: + await r.sadd("{foo}a", "1", "2", "3") + assert await r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 3 + assert await r.smembers("{foo}c") == {b"1", b"2", b"3"} + await r.sadd("{foo}b", "2", "3") + assert await r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 1 + assert await r.smembers("{foo}c") == {b"1"} + + async def test_cluster_sinter(self, r: RedisCluster) -> None: + await r.sadd("{foo}a", "1", "2", "3") + assert await r.sinter("{foo}a", "{foo}b") == set() + await r.sadd("{foo}b", "2", "3") + assert await r.sinter("{foo}a", "{foo}b") == {b"2", b"3"} + + async def test_cluster_sinterstore(self, r: RedisCluster) -> None: + await r.sadd("{foo}a", "1", "2", "3") + assert await r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 0 + assert await r.smembers("{foo}c") == set() + await r.sadd("{foo}b", "2", "3") + assert await r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 2 + assert await r.smembers("{foo}c") == {b"2", b"3"} + + async def test_cluster_smove(self, r: RedisCluster) -> None: + await r.sadd("{foo}a", "a1", "a2") + await r.sadd("{foo}b", "b1", "b2") + assert await r.smove("{foo}a", "{foo}b", "a1") + assert await r.smembers("{foo}a") == {b"a2"} + assert await r.smembers("{foo}b") == {b"b1", b"b2", b"a1"} + + async def test_cluster_sunion(self, r: RedisCluster) -> None: + await r.sadd("{foo}a", "1", "2") + await r.sadd("{foo}b", "2", "3") + assert await r.sunion("{foo}a", "{foo}b") == {b"1", b"2", b"3"} + + async def test_cluster_sunionstore(self, r: RedisCluster) -> None: + await r.sadd("{foo}a", "1", "2") + await r.sadd("{foo}b", "2", "3") + assert await r.sunionstore("{foo}c", "{foo}a", "{foo}b") == 3 + assert await r.smembers("{foo}c") == {b"1", b"2", b"3"} + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_zdiff(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("{foo}b", {"a1": 1, "a2": 2}) + assert await r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] + assert await r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("{foo}b", {"a1": 1, "a2": 2}) + assert await r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) + assert await r.zrange("{foo}out", 0, -1) == [b"a3"] + assert await r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_zinter(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinter(["{foo}a", "{foo}b", "{foo}c"]) == [b"a3", b"a1"] + # invalid aggregation + with pytest.raises(DataError): + await r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True + ) + # aggregate with SUM + assert await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + (b"a3", 8), + (b"a1", 9), + ] + # aggregate with MAX + assert await r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [(b"a3", 5), (b"a1", 6)] + # aggregate with MIN + assert await r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [(b"a1", 1), (b"a3", 1)] + # with weights + assert await r.zinter( + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True + ) == [(b"a3", 20), (b"a1", 23)] + + async def test_cluster_zinterstore_sum(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a3", 8), + (b"a1", 9), + ] + + async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zinterstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX" + ) + == 2 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a3", 5), + (b"a1", 6), + ] + + async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("{foo}b", {"a1": 2, "a2": 3, "a3": 5}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zinterstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN" + ) + == 2 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a3", 3), + ] + + async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a3", 20), + (b"a1", 23), + ] + + @skip_if_server_version_lt("4.9.0") + async def test_cluster_bzpopmax(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 2}) + await r.zadd("{foo}b", {"b1": 10, "b2": 20}) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}b", + b"b2", + 20, + ) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}b", + b"b1", + 10, + ) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}a", + b"a2", + 2, + ) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}a", + b"a1", + 1, + ) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None + await r.zadd("{foo}c", {"c1": 100}) + assert await r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + + @skip_if_server_version_lt("4.9.0") + async def test_cluster_bzpopmin(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 2}) + await r.zadd("{foo}b", {"b1": 10, "b2": 20}) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}b", + b"b1", + 10, + ) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}b", + b"b2", + 20, + ) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}a", + b"a1", + 1, + ) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}a", + b"a2", + 2, + ) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None + await r.zadd("{foo}c", {"c1": 100}) + assert await r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_zrangestore(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zrangestore("{foo}b", "{foo}a", 0, 1) + assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] + assert await r.zrangestore("{foo}b", "{foo}a", 1, 2) + assert await r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] + assert await r.zrange("{foo}b", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a3", 3), + ] + # reversed order + assert await r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) + assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] + # by score + assert await r.zrangestore( + "{foo}b", "{foo}a", 2, 1, byscore=True, offset=0, num=1, desc=True + ) + assert await r.zrange("{foo}b", 0, -1) == [b"a2"] + # by lex + assert await r.zrangestore( + "{foo}b", "{foo}a", "[a2", "(a3", bylex=True, offset=0, num=1 + ) + assert await r.zrange("{foo}b", 0, -1) == [b"a2"] + + @skip_if_server_version_lt("6.2.0") + async def test_cluster_zunion(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + # sum + assert await r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [ + b"a2", + b"a4", + b"a3", + b"a1", + ] + assert await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + # max + assert await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + # min + assert await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + # with weight + assert await r.zunion( + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True + ) == [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)] + + async def test_cluster_zunionstore_sum(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + + async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zunionstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX" + ) + == 4 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] + + async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 4}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zunionstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN" + ) + == 4 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] + + async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + + @skip_if_server_version_lt("2.8.9") + async def test_cluster_pfcount(self, r: RedisCluster) -> None: + members = {b"1", b"2", b"3"} + await r.pfadd("{foo}a", *members) + assert await r.pfcount("{foo}a") == len(members) + members_b = {b"2", b"3", b"4"} + await r.pfadd("{foo}b", *members_b) + assert await r.pfcount("{foo}b") == len(members_b) + assert await r.pfcount("{foo}a", "{foo}b") == len(members_b.union(members)) + + @skip_if_server_version_lt("2.8.9") + async def test_cluster_pfmerge(self, r: RedisCluster) -> None: + mema = {b"1", b"2", b"3"} + memb = {b"2", b"3", b"4"} + memc = {b"5", b"6", b"7"} + await r.pfadd("{foo}a", *mema) + await r.pfadd("{foo}b", *memb) + await r.pfadd("{foo}c", *memc) + await r.pfmerge("{foo}d", "{foo}c", "{foo}a") + assert await r.pfcount("{foo}d") == 6 + await r.pfmerge("{foo}d", "{foo}b") + assert await r.pfcount("{foo}d") == 7 + + async def test_cluster_sort_store(self, r: RedisCluster) -> None: + await r.rpush("{foo}a", "2", "3", "1") + assert await r.sort("{foo}a", store="{foo}sorted_values") == 3 + assert await r.lrange("{foo}sorted_values", 0, -1) == [b"1", b"2", b"3"] + + # GEO COMMANDS + @skip_if_server_version_lt("6.2.0") + async def test_cluster_geosearchstore(self, r: RedisCluster) -> None: + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("{foo}barcelona", values) + await r.geosearchstore( + "{foo}places_barcelona", + "{foo}barcelona", + longitude=2.191, + latitude=41.433, + radius=1000, + ) + assert await r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt("6.2.0") + async def test_geosearchstore_dist(self, r: RedisCluster) -> None: + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("{foo}barcelona", values) + await r.geosearchstore( + "{foo}places_barcelona", + "{foo}barcelona", + longitude=2.191, + latitude=41.433, + radius=1000, + storedist=True, + ) + # instead of save the geo score, the distance is saved. + assert await r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 + + @skip_if_server_version_lt("3.2.0") + async def test_cluster_georadius_store(self, r: RedisCluster) -> None: + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("{foo}barcelona", values) + await r.georadius( + "{foo}barcelona", 2.191, 41.433, 1000, store="{foo}places_barcelona" + ) + assert await r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt("3.2.0") + async def test_cluster_georadius_store_dist(self, r: RedisCluster) -> None: + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("{foo}barcelona", values) + await r.georadius( + "{foo}barcelona", 2.191, 41.433, 1000, store_dist="{foo}places_barcelona" + ) + # instead of save the geo score, the distance is saved. + assert await r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 + + async def test_cluster_dbsize(self, r: RedisCluster) -> None: + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + assert await r.mset_nonatomic(d) + assert await r.dbsize(target_nodes="primaries") == len(d) + + async def test_cluster_keys(self, r: RedisCluster) -> None: + assert await r.keys() == [] + keys_with_underscores = {b"test_a", b"test_b"} + keys = keys_with_underscores.union({b"testc"}) + for key in keys: + await r.set(key, 1) + assert ( + set(await r.keys(pattern="test_*", target_nodes="primaries")) + == keys_with_underscores + ) + assert set(await r.keys(pattern="test*", target_nodes="primaries")) == keys + + # SCAN COMMANDS + @skip_if_server_version_lt("2.8.0") + async def test_cluster_scan(self, r: RedisCluster) -> None: + await r.set("a", 1) + await r.set("b", 2) + await r.set("c", 3) + + for target_nodes, nodes in zip( + ["primaries", "replicas"], [r.get_primaries(), r.get_replicas()] + ): + cursors, keys = await r.scan(target_nodes=target_nodes) + assert sorted(keys) == [b"a", b"b", b"c"] + assert sorted(cursors.keys()) == sorted(node.name for node in nodes) + assert all(cursor == 0 for cursor in cursors.values()) + + cursors, keys = await r.scan(match="a*", target_nodes=target_nodes) + assert sorted(keys) == [b"a"] + assert sorted(cursors.keys()) == sorted(node.name for node in nodes) + assert all(cursor == 0 for cursor in cursors.values()) + + @skip_if_server_version_lt("6.0.0") + async def test_cluster_scan_type(self, r: RedisCluster) -> None: + await r.sadd("a-set", 1) + await r.sadd("b-set", 1) + await r.sadd("c-set", 1) + await r.hset("a-hash", "foo", 2) + await r.lpush("a-list", "aux", 3) + + for target_nodes, nodes in zip( + ["primaries", "replicas"], [r.get_primaries(), r.get_replicas()] + ): + cursors, keys = await r.scan(_type="SET", target_nodes=target_nodes) + assert sorted(keys) == [b"a-set", b"b-set", b"c-set"] + assert sorted(cursors.keys()) == sorted(node.name for node in nodes) + assert all(cursor == 0 for cursor in cursors.values()) + + cursors, keys = await r.scan( + _type="SET", match="a*", target_nodes=target_nodes + ) + assert sorted(keys) == [b"a-set"] + assert sorted(cursors.keys()) == sorted(node.name for node in nodes) + assert all(cursor == 0 for cursor in cursors.values()) + + @skip_if_server_version_lt("2.8.0") + async def test_cluster_scan_iter(self, r: RedisCluster) -> None: + keys_all = [] + keys_1 = [] + for i in range(100): + s = str(i) + await r.set(s, 1) + keys_all.append(s.encode("utf-8")) + if s.startswith("1"): + keys_1.append(s.encode("utf-8")) + keys_all.sort() + keys_1.sort() + + for target_nodes in ["primaries", "replicas"]: + keys = [key async for key in r.scan_iter(target_nodes=target_nodes)] + assert sorted(keys) == keys_all + + keys = [ + key async for key in r.scan_iter(match="1*", target_nodes=target_nodes) + ] + assert sorted(keys) == keys_1 + + async def test_cluster_randomkey(self, r: RedisCluster) -> None: + node = r.get_node_from_key("{foo}") + assert await r.randomkey(target_nodes=node) is None + for key in ("{foo}a", "{foo}b", "{foo}c"): + await r.set(key, 1) + assert await r.randomkey(target_nodes=node) in (b"{foo}a", b"{foo}b", b"{foo}c") + + @skip_if_server_version_lt("6.0.0") + @skip_if_redis_enterprise() + async def test_acl_log( + self, r: RedisCluster, request: FixtureRequest, create_redis: Callable + ) -> None: + key = "{cache}:" + node = r.get_node_from_key(key) + username = "redis-py-user" + + await r.acl_setuser( + username, + enabled=True, + reset=True, + commands=["+get", "+set", "+select", "+cluster", "+command", "+info"], + keys=["{cache}:*"], + nopass=True, + target_nodes="primaries", + ) + await r.acl_log_reset(target_nodes=node) + + user_client = await create_redis( + cls=RedisCluster, flushdb=False, username=username + ) + + # Valid operation and key + assert await user_client.set("{cache}:0", 1) + assert await user_client.get("{cache}:0") == b"1" + + # Invalid key + with pytest.raises(NoPermissionError): + await user_client.get("{cache}violated_cache:0") + + # Invalid operation + with pytest.raises(NoPermissionError): + await user_client.hset("{cache}:0", "hkey", "hval") + + assert isinstance(await r.acl_log(target_nodes=node), list) + assert len(await r.acl_log(target_nodes=node)) == 2 + assert len(await r.acl_log(count=1, target_nodes=node)) == 1 + assert isinstance((await r.acl_log(target_nodes=node))[0], dict) + assert "client-info" in (await r.acl_log(count=1, target_nodes=node))[0] + assert await r.acl_log_reset(target_nodes=node) + + await r.acl_deluser(username, target_nodes="primaries") + + await user_client.close() + + +@pytest.mark.onlycluster +class TestNodesManager: + """ + Tests for the NodesManager class + """ + + async def test_load_balancer(self, r: RedisCluster) -> None: + n_manager = r.nodes_manager + lb = n_manager.read_load_balancer + slot_1 = 1257 + slot_2 = 8975 + node_1 = ClusterNode(default_host, 6379, PRIMARY) + node_2 = ClusterNode(default_host, 6378, REPLICA) + node_3 = ClusterNode(default_host, 6377, REPLICA) + node_4 = ClusterNode(default_host, 6376, PRIMARY) + node_5 = ClusterNode(default_host, 6375, REPLICA) + n_manager.slots_cache = { + slot_1: [node_1, node_2, node_3], + slot_2: [node_4, node_5], + } + primary1_name = n_manager.slots_cache[slot_1][0].name + primary2_name = n_manager.slots_cache[slot_2][0].name + list1_size = len(n_manager.slots_cache[slot_1]) + list2_size = len(n_manager.slots_cache[slot_2]) + # slot 1 + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary1_name, list1_size) == 1 + assert lb.get_server_index(primary1_name, list1_size) == 2 + assert lb.get_server_index(primary1_name, list1_size) == 0 + # slot 2 + assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 1 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + lb.reset() + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + async def test_init_slots_cache_not_all_slots_covered(self) -> None: + """ + Test that if not all slots are covered it should raise an exception + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], + ] + with pytest.raises(RedisClusterException) as ex: + rc = await get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=True, + ) + await rc.close() + assert str(ex.value).startswith( + "All slots are not covered after query all startup_nodes." + ) + + async def test_init_slots_cache_not_require_full_coverage_success(self) -> None: + """ + When require_full_coverage is set to False and not all slots are + covered the cluster client initialization should succeed + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], + ] + + rc = await get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=False, + ) + + assert 5460 not in rc.nodes_manager.slots_cache + + await rc.close() + + async def test_init_slots_cache(self) -> None: + """ + Test that slots cache can in initialized and all slots are covered + """ + good_slots_resp = [ + [0, 5460, ["127.0.0.1", 7000], ["127.0.0.2", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.2", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.2", 7005]], + ] + + rc = await get_mocked_redis_client( + host=default_host, port=default_port, cluster_slots=good_slots_resp + ) + n_manager = rc.nodes_manager + assert len(n_manager.slots_cache) == REDIS_CLUSTER_HASH_SLOTS + for slot_info in good_slots_resp: + all_hosts = ["127.0.0.1", "127.0.0.2"] + all_ports = [7000, 7001, 7002, 7003, 7004, 7005] + slot_start = slot_info[0] + slot_end = slot_info[1] + for i in range(slot_start, slot_end + 1): + assert len(n_manager.slots_cache[i]) == len(slot_info[2:]) + assert n_manager.slots_cache[i][0].host in all_hosts + assert n_manager.slots_cache[i][1].host in all_hosts + assert n_manager.slots_cache[i][0].port in all_ports + assert n_manager.slots_cache[i][1].port in all_ports + + assert len(n_manager.nodes_cache) == 6 + + await rc.close() + + async def test_init_slots_cache_cluster_mode_disabled(self) -> None: + """ + Test that creating a RedisCluster failes if one of the startup nodes + has cluster mode disabled + """ + with pytest.raises(RedisClusterException) as e: + rc = await get_mocked_redis_client( + host=default_host, port=default_port, cluster_enabled=False + ) + await rc.close() + assert "Cluster mode is not enabled on this node" in str(e.value) + + async def test_empty_startup_nodes(self) -> None: + """ + It should not be possible to create a node manager with no nodes + specified + """ + with pytest.raises(RedisClusterException): + await NodesManager([]).initialize() + + async def test_wrong_startup_nodes_type(self) -> None: + """ + If something other then a list type itteratable is provided it should + fail + """ + with pytest.raises(RedisClusterException): + await NodesManager({}).initialize() + + async def test_init_slots_cache_slots_collision( + self, request: FixtureRequest + ) -> None: + """ + Test that if 2 nodes do not agree on the same slots setup it should + raise an error. In this test both nodes will say that the first + slots block should be bound to different servers. + """ + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: + + async def mocked_execute_command(self, *args, **kwargs): + """ + Helper function to return custom slots cache data from + different redis nodes + """ + if self.port == 7000: + result = [ + [0, 5460, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + ] + + elif self.port == 7001: + result = [ + [0, 5460, ["127.0.0.1", 7001], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7000], ["127.0.0.1", 7004]], + ] + else: + result = [] + + if args[0] == "CLUSTER SLOTS": + return result + elif args[0] == "INFO": + return {"cluster_enabled": True} + elif args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} + + execute_command.side_effect = mocked_execute_command + + with pytest.raises(RedisClusterException) as ex: + node_1 = ClusterNode("127.0.0.1", 7000) + node_2 = ClusterNode("127.0.0.1", 7001) + async with RedisCluster(startup_nodes=[node_1, node_2]): + ... + assert str(ex.value).startswith( + "startup_nodes could not agree on a valid slots cache" + ), str(ex.value) + + async def test_cluster_one_instance(self) -> None: + """ + If the cluster exists of only 1 node then there is some hacks that must + be validated they work. + """ + node = ClusterNode(default_host, default_port) + cluster_slots = [[0, 16383, ["", default_port]]] + rc = await get_mocked_redis_client( + startup_nodes=[node], cluster_slots=cluster_slots + ) + + n = rc.nodes_manager + assert len(n.nodes_cache) == 1 + n_node = rc.get_node(node_name=node.name) + assert n_node is not None + assert n_node == node + assert n_node.server_type == PRIMARY + assert len(n.slots_cache) == REDIS_CLUSTER_HASH_SLOTS + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + assert n.slots_cache[i] == [n_node] + + await rc.close() + + async def test_init_with_down_node(self) -> None: + """ + If I can't connect to one of the nodes, everything should still work. + But if I can't connect to any of the nodes, exception should be thrown. + """ + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: + + async def mocked_execute_command(self, *args, **kwargs): + if self.port == 7000: + raise ConnectionError("mock connection error for 7000") + + if args[0] == "CLUSTER SLOTS": + return [ + [0, 8191, ["127.0.0.1", 7001, "node_1"]], + [8192, 16383, ["127.0.0.1", 7002, "node_2"]], + ] + elif args[0] == "INFO": + return {"cluster_enabled": True} + elif args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} + + execute_command.side_effect = mocked_execute_command + + node_1 = ClusterNode("127.0.0.1", 7000) + node_2 = ClusterNode("127.0.0.1", 7001) + + # If all startup nodes fail to connect, connection error should be + # thrown + with pytest.raises(RedisClusterException) as e: + async with RedisCluster(startup_nodes=[node_1]): + ... + assert "Redis Cluster cannot be connected" in str(e.value) + + with mock.patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = { + "GET": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + # When at least one startup node is reachable, the cluster + # initialization should succeeds + async with RedisCluster(startup_nodes=[node_1, node_2]) as rc: + assert rc.get_node(host=default_host, port=7001) is not None + assert rc.get_node(host=default_host, port=7002) is not None diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index dee8755..650ce27 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -4,11 +4,17 @@ Tests async overrides of commands from their mixins import binascii import datetime import re +import sys import time from string import ascii_letters import pytest +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + import redis from redis import exceptions from redis.client import parse_info @@ -21,10 +27,10 @@ from tests.conftest import ( REDIS_6_VERSION = "5.9.0" -pytestmark = [pytest.mark.asyncio, pytest.mark.onlynoncluster] +pytestmark = pytest.mark.asyncio -@pytest.fixture() +@pytest_asyncio.fixture() async def slowlog(r: redis.Redis, event_loop): current_config = await r.config_get() old_slower_than_value = current_config["slowlog-log-slower-than"] @@ -53,6 +59,7 @@ async def get_stream_message(client: redis.Redis, stream: str, message_id: str): # RESPONSE CALLBACKS +@pytest.mark.onlynoncluster class TestResponseCallbacks: """Tests for the response callback system""" @@ -243,6 +250,7 @@ class TestRedisCommands: assert f"user {username} off sanitize-payload &* -@all" in users @skip_if_server_version_lt(REDIS_6_VERSION) + @pytest.mark.onlynoncluster async def test_acl_log(self, r: redis.Redis, request, event_loop, create_redis): username = "redis-py-user" @@ -350,6 +358,7 @@ class TestRedisCommands: username = await r.acl_whoami() assert isinstance(username, str) + @pytest.mark.onlynoncluster async def test_client_list(self, r: redis.Redis): clients = await r.client_list() assert isinstance(clients[0], dict) @@ -364,10 +373,12 @@ class TestRedisCommands: assert isinstance(clients, list) @skip_if_server_version_lt("5.0.0") + @pytest.mark.onlynoncluster async def test_client_id(self, r: redis.Redis): assert await r.client_id() > 0 @skip_if_server_version_lt("5.0.0") + @pytest.mark.onlynoncluster async def test_client_unblock(self, r: redis.Redis): myid = await r.client_id() assert not await r.client_unblock(myid) @@ -375,15 +386,18 @@ class TestRedisCommands: assert not await r.client_unblock(myid, error=False) @skip_if_server_version_lt("2.6.9") + @pytest.mark.onlynoncluster async def test_client_getname(self, r: redis.Redis): assert await r.client_getname() is None @skip_if_server_version_lt("2.6.9") + @pytest.mark.onlynoncluster async def test_client_setname(self, r: redis.Redis): assert await r.client_setname("redis_py_test") assert await r.client_getname() == "redis_py_test" @skip_if_server_version_lt("2.6.9") + @pytest.mark.onlynoncluster async def test_client_kill(self, r: redis.Redis, r2): await r.client_setname("redis-py-c1") await r2.client_setname("redis-py-c2") @@ -422,6 +436,7 @@ class TestRedisCommands: await r.client_kill_filter(_type="caster") # type: ignore @skip_if_server_version_lt("2.8.12") + @pytest.mark.onlynoncluster async def test_client_kill_filter_by_id(self, r: redis.Redis, r2): await r.client_setname("redis-py-c1") await r2.client_setname("redis-py-c2") @@ -447,6 +462,7 @@ class TestRedisCommands: assert clients[0].get("name") == "redis-py-c1" @skip_if_server_version_lt("2.8.12") + @pytest.mark.onlynoncluster async def test_client_kill_filter_by_addr(self, r: redis.Redis, r2): await r.client_setname("redis-py-c1") await r2.client_setname("redis-py-c2") @@ -479,6 +495,7 @@ class TestRedisCommands: assert "redis_py_test" in [c["name"] for c in clients] @skip_if_server_version_lt("2.9.50") + @pytest.mark.onlynoncluster async def test_client_pause(self, r: redis.Redis): assert await r.client_pause(1) assert await r.client_pause(timeout=1) @@ -490,6 +507,7 @@ class TestRedisCommands: assert "maxmemory" in data assert data["maxmemory"].isdigit() + @pytest.mark.onlynoncluster async def test_config_resetstat(self, r: redis.Redis): await r.ping() prior_commands_processed = int((await r.info())["total_commands_processed"]) @@ -507,14 +525,17 @@ class TestRedisCommands: finally: assert await r.config_set("dbfilename", rdbname) + @pytest.mark.onlynoncluster async def test_dbsize(self, r: redis.Redis): await r.set("a", "foo") await r.set("b", "bar") assert await r.dbsize() == 2 + @pytest.mark.onlynoncluster async def test_echo(self, r: redis.Redis): assert await r.echo("foo bar") == b"foo bar" + @pytest.mark.onlynoncluster async def test_info(self, r: redis.Redis): await r.set("a", "foo") await r.set("b", "bar") @@ -522,6 +543,7 @@ class TestRedisCommands: assert isinstance(info, dict) assert info["db9"]["keys"] == 2 + @pytest.mark.onlynoncluster async def test_lastsave(self, r: redis.Redis): assert isinstance(await r.lastsave(), datetime.datetime) @@ -535,6 +557,7 @@ class TestRedisCommands: async def test_ping(self, r: redis.Redis): assert await r.ping() + @pytest.mark.onlynoncluster async def test_slowlog_get(self, r: redis.Redis, slowlog): assert await r.slowlog_reset() unicode_string = chr(3456) + "abcd" + chr(3421) @@ -556,6 +579,7 @@ class TestRedisCommands: assert isinstance(slowlog[0]["start_time"], int) assert isinstance(slowlog[0]["duration"], int) + @pytest.mark.onlynoncluster async def test_slowlog_get_limit(self, r: redis.Redis, slowlog): assert await r.slowlog_reset() await r.get("foo") @@ -564,6 +588,7 @@ class TestRedisCommands: # only one command, based on the number we passed to slowlog_get() assert len(slowlog) == 1 + @pytest.mark.onlynoncluster async def test_slowlog_length(self, r: redis.Redis, slowlog): await r.get("foo") assert isinstance(await r.slowlog_len(), int) @@ -602,12 +627,14 @@ class TestRedisCommands: assert await r.bitcount("a", 1, 1) == 1 @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_not_empty_string(self, r: redis.Redis): await r.set("a", "") await r.bitop("not", "r", "a") assert await r.get("r") is None @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_not(self, r: redis.Redis): test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF @@ -616,6 +643,7 @@ class TestRedisCommands: assert int(binascii.hexlify(await r.get("r")), 16) == correct @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_not_in_place(self, r: redis.Redis): test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF @@ -624,6 +652,7 @@ class TestRedisCommands: assert int(binascii.hexlify(await r.get("a")), 16) == correct @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_single_string(self, r: redis.Redis): test_str = b"\x01\x02\xFF" await r.set("a", test_str) @@ -635,6 +664,7 @@ class TestRedisCommands: assert await r.get("res3") == test_str @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_string_operands(self, r: redis.Redis): await r.set("a", b"\x01\x02\xFF\xFF") await r.set("b", b"\x01\x02\xFF") @@ -645,6 +675,7 @@ class TestRedisCommands: assert int(binascii.hexlify(await r.get("res2")), 16) == 0x0102FFFF assert int(binascii.hexlify(await r.get("res3")), 16) == 0x000000FF + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.7") async def test_bitpos(self, r: redis.Redis): key = "key:bitpos" @@ -840,6 +871,7 @@ class TestRedisCommands: assert await r.incrbyfloat("a", 1.1) == 2.1 assert float(await r.get("a")) == float(2.1) + @pytest.mark.onlynoncluster async def test_keys(self, r: redis.Redis): assert await r.keys() == [] keys_with_underscores = {b"test_a", b"test_b"} @@ -849,6 +881,7 @@ class TestRedisCommands: assert set(await r.keys(pattern="test_*")) == keys_with_underscores assert set(await r.keys(pattern="test*")) == keys + @pytest.mark.onlynoncluster async def test_mget(self, r: redis.Redis): assert await r.mget([]) == [] assert await r.mget(["a", "b"]) == [None, None] @@ -857,12 +890,14 @@ class TestRedisCommands: await r.set("c", "3") assert await r.mget("a", "other", "b", "c") == [b"1", None, b"2", b"3"] + @pytest.mark.onlynoncluster async def test_mset(self, r: redis.Redis): d = {"a": b"1", "b": b"2", "c": b"3"} assert await r.mset(d) for k, v in d.items(): assert await r.get(k) == v + @pytest.mark.onlynoncluster async def test_msetnx(self, r: redis.Redis): d = {"a": b"1", "b": b"2", "c": b"3"} assert await r.msetnx(d) @@ -928,18 +963,21 @@ class TestRedisCommands: """PTTL on servers 2.8 and after return -2 when the key doesn't exist""" assert await r.pttl("a") == -2 + @pytest.mark.onlynoncluster async def test_randomkey(self, r: redis.Redis): assert await r.randomkey() is None for key in ("a", "b", "c"): await r.set(key, 1) assert await r.randomkey() in (b"a", b"b", b"c") + @pytest.mark.onlynoncluster async def test_rename(self, r: redis.Redis): await r.set("a", "1") assert await r.rename("a", "b") assert await r.get("a") is None assert await r.get("b") == b"1" + @pytest.mark.onlynoncluster async def test_renamenx(self, r: redis.Redis): await r.set("a", "1") await r.set("b", "2") @@ -1057,6 +1095,7 @@ class TestRedisCommands: assert await r.type("a") == b"zset" # LIST COMMANDS + @pytest.mark.onlynoncluster async def test_blpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") @@ -1068,6 +1107,7 @@ class TestRedisCommands: await r.rpush("c", "1") assert await r.blpop("c", timeout=1) == (b"c", b"1") + @pytest.mark.onlynoncluster async def test_brpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") @@ -1079,6 +1119,7 @@ class TestRedisCommands: await r.rpush("c", "1") assert await r.brpop("c", timeout=1) == (b"c", b"1") + @pytest.mark.onlynoncluster async def test_brpoplpush(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") @@ -1088,6 +1129,7 @@ class TestRedisCommands: assert await r.lrange("a", 0, -1) == [] assert await r.lrange("b", 0, -1) == [b"1", b"2", b"3", b"4"] + @pytest.mark.onlynoncluster async def test_brpoplpush_empty_string(self, r: redis.Redis): await r.rpush("a", "") assert await r.brpoplpush("a", "b") == b"" @@ -1165,6 +1207,7 @@ class TestRedisCommands: assert await r.rpop("a") == b"1" assert await r.rpop("a") is None + @pytest.mark.onlynoncluster async def test_rpoplpush(self, r: redis.Redis): await r.rpush("a", "a1", "a2", "a3") await r.rpush("b", "b1", "b2", "b3") @@ -1219,6 +1262,7 @@ class TestRedisCommands: # SCAN COMMANDS @skip_if_server_version_lt("2.8.0") + @pytest.mark.onlynoncluster async def test_scan(self, r: redis.Redis): await r.set("a", 1) await r.set("b", 2) @@ -1230,6 +1274,7 @@ class TestRedisCommands: assert set(keys) == {b"a"} @skip_if_server_version_lt(REDIS_6_VERSION) + @pytest.mark.onlynoncluster async def test_scan_type(self, r: redis.Redis): await r.sadd("a-set", 1) await r.hset("a-hash", "foo", 2) @@ -1238,6 +1283,7 @@ class TestRedisCommands: assert set(keys) == {b"a-set"} @skip_if_server_version_lt("2.8.0") + @pytest.mark.onlynoncluster async def test_scan_iter(self, r: redis.Redis): await r.set("a", 1) await r.set("b", 2) @@ -1308,12 +1354,14 @@ class TestRedisCommands: await r.sadd("a", "1", "2", "3") assert await r.scard("a") == 3 + @pytest.mark.onlynoncluster async def test_sdiff(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.sdiff("a", "b") == {b"1", b"2", b"3"} await r.sadd("b", "2", "3") assert await r.sdiff("a", "b") == {b"1"} + @pytest.mark.onlynoncluster async def test_sdiffstore(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.sdiffstore("c", "a", "b") == 3 @@ -1322,12 +1370,14 @@ class TestRedisCommands: assert await r.sdiffstore("c", "a", "b") == 1 assert await r.smembers("c") == {b"1"} + @pytest.mark.onlynoncluster async def test_sinter(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.sinter("a", "b") == set() await r.sadd("b", "2", "3") assert await r.sinter("a", "b") == {b"2", b"3"} + @pytest.mark.onlynoncluster async def test_sinterstore(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.sinterstore("c", "a", "b") == 0 @@ -1347,6 +1397,7 @@ class TestRedisCommands: await r.sadd("a", "1", "2", "3") assert await r.smembers("a") == {b"1", b"2", b"3"} + @pytest.mark.onlynoncluster async def test_smove(self, r: redis.Redis): await r.sadd("a", "a1", "a2") await r.sadd("b", "b1", "b2") @@ -1392,11 +1443,13 @@ class TestRedisCommands: assert await r.srem("a", "2", "4") == 2 assert await r.smembers("a") == {b"1", b"3"} + @pytest.mark.onlynoncluster async def test_sunion(self, r: redis.Redis): await r.sadd("a", "1", "2") await r.sadd("b", "2", "3") assert await r.sunion("a", "b") == {b"1", b"2", b"3"} + @pytest.mark.onlynoncluster async def test_sunionstore(self, r: redis.Redis): await r.sadd("a", "1", "2") await r.sadd("b", "2", "3") @@ -1481,6 +1534,7 @@ class TestRedisCommands: assert await r.zlexcount("a", "-", "+") == 7 assert await r.zlexcount("a", "[b", "[f") == 5 + @pytest.mark.onlynoncluster async def test_zinterstore_sum(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1488,6 +1542,7 @@ class TestRedisCommands: assert await r.zinterstore("d", ["a", "b", "c"]) == 2 assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + @pytest.mark.onlynoncluster async def test_zinterstore_max(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1495,6 +1550,7 @@ class TestRedisCommands: assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + @pytest.mark.onlynoncluster async def test_zinterstore_min(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) @@ -1502,6 +1558,7 @@ class TestRedisCommands: assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 assert await r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + @pytest.mark.onlynoncluster async def test_zinterstore_with_weight(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1526,6 +1583,7 @@ class TestRedisCommands: assert await r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] @skip_if_server_version_lt("4.9.0") + @pytest.mark.onlynoncluster async def test_bzpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) @@ -1538,6 +1596,7 @@ class TestRedisCommands: assert await r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) @skip_if_server_version_lt("4.9.0") + @pytest.mark.onlynoncluster async def test_bzpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) @@ -1705,6 +1764,7 @@ class TestRedisCommands: assert await r.zscore("a", "a2") == 2.0 assert await r.zscore("a", "a4") is None + @pytest.mark.onlynoncluster async def test_zunionstore_sum(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1717,6 +1777,7 @@ class TestRedisCommands: (b"a1", 9), ] + @pytest.mark.onlynoncluster async def test_zunionstore_max(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1729,6 +1790,7 @@ class TestRedisCommands: (b"a1", 6), ] + @pytest.mark.onlynoncluster async def test_zunionstore_min(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) @@ -1741,6 +1803,7 @@ class TestRedisCommands: (b"a4", 4), ] + @pytest.mark.onlynoncluster async def test_zunionstore_with_weight(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1762,6 +1825,7 @@ class TestRedisCommands: assert await r.pfcount("a") == len(members) @skip_if_server_version_lt("2.8.9") + @pytest.mark.onlynoncluster async def test_pfcount(self, r: redis.Redis): members = {b"1", b"2", b"3"} await r.pfadd("a", *members) @@ -1772,6 +1836,7 @@ class TestRedisCommands: assert await r.pfcount("a", "b") == len(members_b.union(members)) @skip_if_server_version_lt("2.8.9") + @pytest.mark.onlynoncluster async def test_pfmerge(self, r: redis.Redis): mema = {b"1", b"2", b"3"} memb = {b"2", b"3", b"4"} @@ -1866,7 +1931,8 @@ class TestRedisCommands: async def test_hmset(self, r: redis.Redis): warning_message = ( - r"^Redis\.hmset\(\) is deprecated\. " r"Use Redis\.hset\(\) instead\.$" + r"^Redis(?:Cluster)*\.hmset\(\) is deprecated\. " + r"Use Redis(?:Cluster)*\.hset\(\) instead\.$" ) h = {b"a": b"1", b"b": b"2", b"c": b"3"} with pytest.warns(DeprecationWarning, match=warning_message): @@ -1902,6 +1968,7 @@ class TestRedisCommands: await r.rpush("a", "3", "2", "1", "4") assert await r.sort("a", start=1, num=2) == [b"2", b"3"] + @pytest.mark.onlynoncluster async def test_sort_by(self, r: redis.Redis): await r.set("score:1", 8) await r.set("score:2", 3) @@ -1909,6 +1976,7 @@ class TestRedisCommands: await r.rpush("a", "3", "2", "1") assert await r.sort("a", by="score:*") == [b"2", b"3", b"1"] + @pytest.mark.onlynoncluster async def test_sort_get(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1916,6 +1984,7 @@ class TestRedisCommands: await r.rpush("a", "2", "3", "1") assert await r.sort("a", get="user:*") == [b"u1", b"u2", b"u3"] + @pytest.mark.onlynoncluster async def test_sort_get_multi(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1930,6 +1999,7 @@ class TestRedisCommands: b"3", ] + @pytest.mark.onlynoncluster async def test_sort_get_groups_two(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1941,6 +2011,7 @@ class TestRedisCommands: (b"u3", b"3"), ] + @pytest.mark.onlynoncluster async def test_sort_groups_string_get(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1949,6 +2020,7 @@ class TestRedisCommands: with pytest.raises(exceptions.DataError): await r.sort("a", get="user:*", groups=True) + @pytest.mark.onlynoncluster async def test_sort_groups_just_one_get(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1965,6 +2037,7 @@ class TestRedisCommands: with pytest.raises(exceptions.DataError): await r.sort("a", groups=True) + @pytest.mark.onlynoncluster async def test_sort_groups_three_gets(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1987,11 +2060,13 @@ class TestRedisCommands: await r.rpush("a", "e", "c", "b", "d", "a") assert await r.sort("a", alpha=True) == [b"a", b"b", b"c", b"d", b"e"] + @pytest.mark.onlynoncluster async def test_sort_store(self, r: redis.Redis): await r.rpush("a", "2", "3", "1") assert await r.sort("a", store="sorted_values") == 3 assert await r.lrange("sorted_values", 0, -1) == [b"1", b"2", b"3"] + @pytest.mark.onlynoncluster async def test_sort_all_options(self, r: redis.Redis): await r.set("user:1:username", "zeus") await r.set("user:2:username", "titan") @@ -2035,70 +2110,88 @@ class TestRedisCommands: await r.execute_command("SADD", "issue#924", 1) await r.execute_command("SORT", "issue#924") + @pytest.mark.onlynoncluster async def test_cluster_addslots(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("ADDSLOTS", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_count_failure_reports(self, mock_cluster_resp_int): assert isinstance( await mock_cluster_resp_int.cluster("COUNT-FAILURE-REPORTS", "node"), int ) + @pytest.mark.onlynoncluster async def test_cluster_countkeysinslot(self, mock_cluster_resp_int): assert isinstance( await mock_cluster_resp_int.cluster("COUNTKEYSINSLOT", 2), int ) + @pytest.mark.onlynoncluster async def test_cluster_delslots(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("DELSLOTS", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_failover(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("FAILOVER", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_forget(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("FORGET", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_info(self, mock_cluster_resp_info): assert isinstance(await mock_cluster_resp_info.cluster("info"), dict) + @pytest.mark.onlynoncluster async def test_cluster_keyslot(self, mock_cluster_resp_int): assert isinstance(await mock_cluster_resp_int.cluster("keyslot", "asdf"), int) + @pytest.mark.onlynoncluster async def test_cluster_meet(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("meet", "ip", "port", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_nodes(self, mock_cluster_resp_nodes): assert isinstance(await mock_cluster_resp_nodes.cluster("nodes"), dict) + @pytest.mark.onlynoncluster async def test_cluster_replicate(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("replicate", "nodeid") is True + @pytest.mark.onlynoncluster async def test_cluster_reset(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("reset", "hard") is True + @pytest.mark.onlynoncluster async def test_cluster_saveconfig(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("saveconfig") is True + @pytest.mark.onlynoncluster async def test_cluster_setslot(self, mock_cluster_resp_ok): assert ( await mock_cluster_resp_ok.cluster("setslot", 1, "IMPORTING", "nodeid") is True ) + @pytest.mark.onlynoncluster async def test_cluster_slaves(self, mock_cluster_resp_slaves): assert isinstance( await mock_cluster_resp_slaves.cluster("slaves", "nodeid"), dict ) @skip_if_server_version_lt("3.0.0") + @pytest.mark.onlynoncluster async def test_readwrite(self, r: redis.Redis): assert await r.readwrite() @skip_if_server_version_lt("3.0.0") + @pytest.mark.onlynoncluster async def test_readonly_invalid_cluster_state(self, r: redis.Redis): with pytest.raises(exceptions.RedisError): await r.readonly() @skip_if_server_version_lt("3.0.0") + @pytest.mark.onlynoncluster async def test_readonly(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.readonly() is True @@ -2315,6 +2408,7 @@ class TestRedisCommands: ] @skip_if_server_version_lt("3.2.0") + @pytest.mark.onlynoncluster async def test_georadius_store(self, r: redis.Redis): values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, @@ -2328,6 +2422,7 @@ class TestRedisCommands: @skip_unless_arch_bits(64) @skip_if_server_version_lt("3.2.0") + @pytest.mark.onlynoncluster async def test_georadius_store_dist(self, r: redis.Redis): values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, @@ -2723,25 +2818,11 @@ class TestRedisCommands: # xread starting at 0 returns both messages assert await r.xread(streams={stream: 0}) == expected - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - ], - ] - ] + expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] # xread starting at 0 and count=1 returns only the first message assert await r.xread(streams={stream: 0}, count=1) == expected - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m2), - ], - ] - ] + expected = [[stream.encode(), [await get_stream_message(r, stream, m2)]]] # xread starting at m1 returns only the second message assert await r.xread(streams={stream: m1}) == expected @@ -2772,14 +2853,7 @@ class TestRedisCommands: await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - ], - ] - ] + expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] # xread with count=1 returns only the first message assert ( await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) @@ -2817,15 +2891,7 @@ class TestRedisCommands: await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [ - [ - stream.encode(), - [ - (m1, {}), - (m2, {}), - ], - ] - ] + expected = [[stream.encode(), [(m1, {}), (m2, {})]]] await r.xreadgroup(group, consumer, streams={stream: ">"}) await r.xtrim(stream, 0) assert await r.xreadgroup(group, consumer, streams={stream: "0"}) == expected @@ -2872,6 +2938,7 @@ class TestRedisCommands: # 1 message is trimmed assert await r.xtrim(stream, 3, approximate=False) == 1 + @pytest.mark.onlynoncluster async def test_bitfield_operations(self, r: redis.Redis): # comments show affected bits await r.execute_command("SELECT", 10) @@ -2958,11 +3025,13 @@ class TestRedisCommands: assert isinstance(await r.memory_usage("foo"), int) @skip_if_server_version_lt("4.0.0") + @pytest.mark.onlynoncluster async def test_module_list(self, r: redis.Redis): assert isinstance(await r.module_list(), list) assert not await r.module_list() +@pytest.mark.onlynoncluster class TestBinarySave: async def test_binary_get_set(self, r: redis.Redis): assert await r.set(" foo bar ", "123") diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 46abec0..f6259ad 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -51,14 +51,12 @@ async def test_loading_external_modules(modclient): # assert mod.get('fookey') == d -@pytest.mark.onlynoncluster async def test_socket_param_regression(r): """A regression test for issue #1060""" conn = UnixDomainSocketConnection() _ = await conn.disconnect() is True -@pytest.mark.onlynoncluster async def test_can_run_concurrent_commands(r): assert await r.ping() is True assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index f9dfefd..6c56558 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -20,6 +20,7 @@ from .test_pubsub import wait_for_message pytestmark = pytest.mark.asyncio +@pytest.mark.onlynoncluster class TestRedisAutoReleaseConnectionPool: @pytest_asyncio.fixture async def r(self, create_redis) -> redis.Redis: @@ -112,7 +113,6 @@ class DummyConnection(Connection): return False -@pytest.mark.onlynoncluster class TestConnectionPool: def get_pool( self, @@ -189,7 +189,6 @@ class TestConnectionPool: assert repr(pool) == expected -@pytest.mark.onlynoncluster class TestBlockingConnectionPool: def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): connection_kwargs = connection_kwargs or {} @@ -296,38 +295,27 @@ class TestBlockingConnectionPool: assert repr(pool) == expected -@pytest.mark.onlynoncluster class TestConnectionPoolURLParsing: def test_hostname(self): pool = redis.ConnectionPool.from_url("redis://my.host") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "my.host", - } + assert pool.connection_kwargs == {"host": "my.host"} def test_quoted_hostname(self): pool = redis.ConnectionPool.from_url("redis://my %2F host %2B%3D+") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "my / host +=+", - } + assert pool.connection_kwargs == {"host": "my / host +=+"} def test_port(self): pool = redis.ConnectionPool.from_url("redis://localhost:6380") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6380, - } + assert pool.connection_kwargs == {"host": "localhost", "port": 6380} @skip_if_server_version_lt("6.0.0") def test_username(self): pool = redis.ConnectionPool.from_url("redis://myuser:@localhost") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "username": "myuser", - } + assert pool.connection_kwargs == {"host": "localhost", "username": "myuser"} @skip_if_server_version_lt("6.0.0") def test_quoted_username(self): @@ -343,10 +331,7 @@ class TestConnectionPoolURLParsing: def test_password(self): pool = redis.ConnectionPool.from_url("redis://:mypassword@localhost") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "password": "mypassword", - } + assert pool.connection_kwargs == {"host": "localhost", "password": "mypassword"} def test_quoted_password(self): pool = redis.ConnectionPool.from_url( @@ -371,26 +356,17 @@ class TestConnectionPoolURLParsing: def test_db_as_argument(self): pool = redis.ConnectionPool.from_url("redis://localhost", db=1) assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "db": 1, - } + assert pool.connection_kwargs == {"host": "localhost", "db": 1} def test_db_in_path(self): pool = redis.ConnectionPool.from_url("redis://localhost/2", db=1) assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "db": 2, - } + assert pool.connection_kwargs == {"host": "localhost", "db": 2} def test_db_in_querystring(self): pool = redis.ConnectionPool.from_url("redis://localhost/2?db=3", db=1) assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "db": 3, - } + assert pool.connection_kwargs == {"host": "localhost", "db": 3} def test_extra_typed_querystring_options(self): pool = redis.ConnectionPool.from_url( @@ -450,9 +426,7 @@ class TestConnectionPoolURLParsing: def test_client_creates_connection_pool(self): r = redis.Redis.from_url("redis://myhost") assert r.connection_pool.connection_class == redis.Connection - assert r.connection_pool.connection_kwargs == { - "host": "myhost", - } + assert r.connection_pool.connection_kwargs == {"host": "myhost"} def test_invalid_scheme_raises_error(self): with pytest.raises(ValueError) as cm: @@ -463,23 +437,17 @@ class TestConnectionPoolURLParsing: ) -@pytest.mark.onlynoncluster class TestConnectionPoolUnixSocketURLParsing: def test_defaults(self): pool = redis.ConnectionPool.from_url("unix:///socket") assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - } + assert pool.connection_kwargs == {"path": "/socket"} @skip_if_server_version_lt("6.0.0") def test_username(self): pool = redis.ConnectionPool.from_url("unix://myuser:@/socket") assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "username": "myuser", - } + assert pool.connection_kwargs == {"path": "/socket", "username": "myuser"} @skip_if_server_version_lt("6.0.0") def test_quoted_username(self): @@ -495,10 +463,7 @@ class TestConnectionPoolUnixSocketURLParsing: def test_password(self): pool = redis.ConnectionPool.from_url("unix://:mypassword@/socket") assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "password": "mypassword", - } + assert pool.connection_kwargs == {"path": "/socket", "password": "mypassword"} def test_quoted_password(self): pool = redis.ConnectionPool.from_url( @@ -523,18 +488,12 @@ class TestConnectionPoolUnixSocketURLParsing: def test_db_as_argument(self): pool = redis.ConnectionPool.from_url("unix:///socket", db=1) assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 1, - } + assert pool.connection_kwargs == {"path": "/socket", "db": 1} def test_db_in_querystring(self): pool = redis.ConnectionPool.from_url("unix:///socket?db=2", db=1) assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 2, - } + assert pool.connection_kwargs == {"path": "/socket", "db": 2} def test_client_name_in_querystring(self): pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") @@ -546,14 +505,11 @@ class TestConnectionPoolUnixSocketURLParsing: assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} -@pytest.mark.onlynoncluster class TestSSLConnectionURLParsing: def test_host(self): pool = redis.ConnectionPool.from_url("rediss://my.host") assert pool.connection_class == redis.SSLConnection - assert pool.connection_kwargs == { - "host": "my.host", - } + assert pool.connection_kwargs == {"host": "my.host"} def test_cert_reqs_options(self): import ssl @@ -578,7 +534,6 @@ class TestSSLConnectionURLParsing: assert pool.get_connection("_").check_hostname is True -@pytest.mark.onlynoncluster class TestConnection: async def test_on_connect_error(self): """ @@ -709,7 +664,7 @@ class TestHealthCheck: def assert_interval_advanced(self, connection): diff = connection.next_health_check - asyncio.get_event_loop().time() - assert self.interval > diff > (self.interval - 1) + assert self.interval >= diff > (self.interval - 1) async def test_health_check_runs(self, r): if r.connection: diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index c496718..8ceb3bc 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -114,7 +114,7 @@ class TestLock: start = event_loop.time() assert not await lock2.acquire() # The elapsed duration should be less than the total blocking_timeout - assert bt > (event_loop.time() - start) > bt - sleep + assert bt >= (event_loop.time() - start) > bt - sleep await lock1.release() async def test_context_manager(self, r): diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index dee83ba..d696d72 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -19,7 +19,6 @@ class BackoffMock(AbstractBackoff): return 0 -@pytest.mark.onlynoncluster class TestConnectionConstructorWithRetry: "Test that the Connection constructors properly handles Retry objects" @@ -41,7 +40,6 @@ class TestConnectionConstructorWithRetry: assert c.retry._retries == retries -@pytest.mark.onlynoncluster class TestRetry: "Test that Retry calls backoff and retries the expected number of times" diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 3794c31..376e3f8 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -665,8 +665,8 @@ class TestClusterRedisCommands: def test_case_insensitive_command_names(self, r): assert ( - r.cluster_response_callbacks["cluster addslots"] - == r.cluster_response_callbacks["CLUSTER ADDSLOTS"] + r.cluster_response_callbacks["cluster slots"] + == r.cluster_response_callbacks["CLUSTER SLOTS"] ) def test_get_and_set(self, r): @@ -1038,7 +1038,7 @@ class TestClusterRedisCommands: @skip_if_redis_enterprise() def test_cluster_get_keys_in_slot(self, r): - response = [b"{foo}1", b"{foo}2"] + response = ["{foo}1", "{foo}2"] node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) keys = r.cluster_get_keys_in_slot(12182, 4) @@ -115,7 +115,7 @@ volumes = [docker:redismod_cluster] name = redismod_cluster image = redisfab/redis-py-modcluster:6.2.6 -ports = +ports = 46379:46379/tcp 46380:46380/tcp 46381:46381/tcp @@ -291,7 +291,7 @@ commands = standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' {posargs} standalone-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop {posargs} cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --redis-unstable-url={env:UNSTABLE_CLUSTER_URL:} {posargs} - cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop {posargs} + cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --redis-unstable-url={env:UNSTABLE_CLUSTER_URL:} --uvloop {posargs} [testenv:redis5] deps = @@ -337,7 +337,6 @@ skipsdist = true skip_install = true deps = -r {toxinidir}/dev_requirements.txt docker = {[testenv]docker} -commands = /usr/bin/echo docker_up [testenv:linters] deps_files = dev_requirements.txt diff --git a/whitelist.py b/whitelist.py index 2721028..8c9cee3 100644 --- a/whitelist.py +++ b/whitelist.py @@ -14,4 +14,6 @@ exc_type # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:2 exc_value # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) traceback # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) AsyncConnectionPool # unused import (//data/repos/redis/redis-py/redis/typing.py:9) +AsyncEncoder # unused import (//data/repos/redis/redis-py/redis/typing.py:10) AsyncRedis # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49) +TargetNodesT # unused import (//data/repos/redis/redis-py/redis/commands/cluster.py:46) |