summaryrefslogtreecommitdiff
path: root/redis/commands/cluster.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/commands/cluster.py')
-rw-r--r--redis/commands/cluster.py204
1 files changed, 114 insertions, 90 deletions
diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py
index b91b65f..a1060d2 100644
--- a/redis/commands/cluster.py
+++ b/redis/commands/cluster.py
@@ -46,25 +46,111 @@ if TYPE_CHECKING:
from redis.asyncio.cluster import TargetNodesT
+# Not complete, but covers the major ones
+# https://redis.io/commands
+READ_COMMANDS = frozenset(
+ [
+ "BITCOUNT",
+ "BITPOS",
+ "EXISTS",
+ "GEODIST",
+ "GEOHASH",
+ "GEOPOS",
+ "GEORADIUS",
+ "GEORADIUSBYMEMBER",
+ "GET",
+ "GETBIT",
+ "GETRANGE",
+ "HEXISTS",
+ "HGET",
+ "HGETALL",
+ "HKEYS",
+ "HLEN",
+ "HMGET",
+ "HSTRLEN",
+ "HVALS",
+ "KEYS",
+ "LINDEX",
+ "LLEN",
+ "LRANGE",
+ "MGET",
+ "PTTL",
+ "RANDOMKEY",
+ "SCARD",
+ "SDIFF",
+ "SINTER",
+ "SISMEMBER",
+ "SMEMBERS",
+ "SRANDMEMBER",
+ "STRLEN",
+ "SUNION",
+ "TTL",
+ "ZCARD",
+ "ZCOUNT",
+ "ZRANGE",
+ "ZSCORE",
+ ]
+)
+
+
class ClusterMultiKeyCommands(ClusterCommandsProtocol):
"""
A class containing commands that handle more than one key
"""
def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]:
- """
- Split keys into a dictionary that maps a slot to
- a list of keys.
- """
+ """Split keys into a dictionary that maps a slot to a list of keys."""
+
slots_to_keys = {}
for key in keys:
- k = self.encoder.encode(key)
- slot = key_slot(k)
+ slot = key_slot(self.encoder.encode(key))
slots_to_keys.setdefault(slot, []).append(key)
return slots_to_keys
- def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]:
+ def _partition_pairs_by_slot(
+ self, mapping: Mapping[AnyKeyT, EncodableT]
+ ) -> Dict[int, List[EncodableT]]:
+ """Split pairs into a dictionary that maps a slot to a list of pairs."""
+
+ slots_to_pairs = {}
+ for pair in mapping.items():
+ slot = key_slot(self.encoder.encode(pair[0]))
+ slots_to_pairs.setdefault(slot, []).extend(pair)
+
+ return slots_to_pairs
+
+ def _execute_pipeline_by_slot(
+ self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
+ ) -> List[Any]:
+ read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
+ pipe = self.pipeline()
+ [
+ pipe.execute_command(
+ command,
+ *slot_args,
+ target_nodes=[
+ self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
+ ],
+ )
+ for slot, slot_args in slots_to_args.items()
+ ]
+ return pipe.execute()
+
+ def _reorder_keys_by_command(
+ self,
+ keys: Iterable[KeyT],
+ slots_to_args: Mapping[int, Iterable[EncodableT]],
+ responses: Iterable[Any],
+ ) -> List[Any]:
+ results = {
+ k: v
+ for slot_values, response in zip(slots_to_args.values(), responses)
+ for k, v in zip(slot_values, response)
+ }
+ return [results[key] for key in keys]
+
+ def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
"""
Splits the keys into different slots and then calls MGET
for the keys of every slot. This operation will not be atomic
@@ -75,30 +161,17 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
For more information see https://redis.io/commands/mget
"""
- from redis.client import EMPTY_RESPONSE
-
- options = {}
- if not args:
- options[EMPTY_RESPONSE] = []
-
# Concatenate all keys into a list
keys = list_or_args(keys, args)
+
# Split keys into slots
slots_to_keys = self._partition_keys_by_slot(keys)
- # Call MGET for every slot and concatenate
- # the results
- # We must make sure that the keys are returned in order
- all_results = {}
- for slot_keys in slots_to_keys.values():
- slot_values = self.execute_command("MGET", *slot_keys, **options)
+ # Execute commands using a pipeline
+ res = self._execute_pipeline_by_slot("MGET", slots_to_keys)
- slot_results = dict(zip(slot_keys, slot_values))
- all_results.update(slot_results)
-
- # Sort the results
- vals_in_order = [all_results[key] for key in keys]
- return vals_in_order
+ # Reorder keys in the order the user provided & return
+ return self._reorder_keys_by_command(keys, slots_to_keys, res)
def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]:
"""
@@ -114,35 +187,22 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
"""
# Partition the keys by slot
- slots_to_pairs = {}
- for pair in mapping.items():
- # encode the key
- k = self.encoder.encode(pair[0])
- slot = key_slot(k)
- slots_to_pairs.setdefault(slot, []).extend(pair)
-
- # Call MSET for every slot and concatenate
- # the results (one result per slot)
- res = []
- for pairs in slots_to_pairs.values():
- res.append(self.execute_command("MSET", *pairs))
+ slots_to_pairs = self._partition_pairs_by_slot(mapping)
- return res
+ # Execute commands using a pipeline & return list of replies
+ return self._execute_pipeline_by_slot("MSET", slots_to_pairs)
def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
"""
Runs the given command once for the keys
of each slot. Returns the sum of the return values.
"""
+
# Partition the keys by slot
slots_to_keys = self._partition_keys_by_slot(keys)
# Sum up the reply from each command
- total = 0
- for slot_keys in slots_to_keys.values():
- total += self.execute_command(command, *slot_keys)
-
- return total
+ return sum(self._execute_pipeline_by_slot(command, slots_to_keys))
def exists(self, *keys: KeyT) -> ResponseT:
"""
@@ -202,7 +262,7 @@ class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands):
A class containing commands that handle more than one key
"""
- async def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]:
+ async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
"""
Splits the keys into different slots and then calls MGET
for the keys of every slot. This operation will not be atomic
@@ -213,36 +273,17 @@ class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands):
For more information see https://redis.io/commands/mget
"""
- from redis.client import EMPTY_RESPONSE
-
- options = {}
- if not args:
- options[EMPTY_RESPONSE] = []
-
# Concatenate all keys into a list
keys = list_or_args(keys, args)
+
# Split keys into slots
slots_to_keys = self._partition_keys_by_slot(keys)
- # Call MGET for every slot and concatenate
- # the results
- # We must make sure that the keys are returned in order
- all_values = await asyncio.gather(
- *(
- asyncio.ensure_future(
- self.execute_command("MGET", *slot_keys, **options)
- )
- for slot_keys in slots_to_keys.values()
- )
- )
+ # Execute commands using a pipeline
+ res = await self._execute_pipeline_by_slot("MGET", slots_to_keys)
- all_results = {}
- for slot_keys, slot_values in zip(slots_to_keys.values(), all_values):
- all_results.update(dict(zip(slot_keys, slot_values)))
-
- # Sort the results
- vals_in_order = [all_results[key] for key in keys]
- return vals_in_order
+ # Reorder keys in the order the user provided & return
+ return self._reorder_keys_by_command(keys, slots_to_keys, res)
async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]:
"""
@@ -258,39 +299,22 @@ class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands):
"""
# Partition the keys by slot
- slots_to_pairs = {}
- for pair in mapping.items():
- # encode the key
- k = self.encoder.encode(pair[0])
- slot = key_slot(k)
- slots_to_pairs.setdefault(slot, []).extend(pair)
+ slots_to_pairs = self._partition_pairs_by_slot(mapping)
- # Call MSET for every slot and concatenate
- # the results (one result per slot)
- return await asyncio.gather(
- *(
- asyncio.ensure_future(self.execute_command("MSET", *pairs))
- for pairs in slots_to_pairs.values()
- )
- )
+ # Execute commands using a pipeline & return list of replies
+ return await self._execute_pipeline_by_slot("MSET", slots_to_pairs)
async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
"""
Runs the given command once for the keys
of each slot. Returns the sum of the return values.
"""
+
# Partition the keys by slot
slots_to_keys = self._partition_keys_by_slot(keys)
# Sum up the reply from each command
- return sum(
- await asyncio.gather(
- *(
- asyncio.ensure_future(self.execute_command(command, *slot_keys))
- for slot_keys in slots_to_keys.values()
- )
- )
- )
+ return sum(await self._execute_pipeline_by_slot(command, slots_to_keys))
class ClusterManagementCommands(ManagementCommands):