diff options
author | Kristján Valur Jónsson <sweskman@gmail.com> | 2023-05-02 15:06:29 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-02 18:06:29 +0300 |
commit | a7857e106bad02f4fc01c6ae69573d53d9018950 (patch) | |
tree | e38146914e690d7af6555ccfa0313110a57097f5 | |
parent | ac15d529edf2832af4c95349f6c0e9af2418448d (diff) | |
download | redis-py-a7857e106bad02f4fc01c6ae69573d53d9018950.tar.gz |
add "address_remap" feature to RedisCluster (#2726)
* add cluster "host_port_remap" feature for asyncio.RedisCluster
* Add a unittest for asyncio.RedisCluster
* Add host_port_remap to _sync_ RedisCluster
* add synchronous tests
* rename arg to `address_remap` and take and return an address tuple.
* Add class documentation
* Add CHANGES
-rw-r--r-- | CHANGES | 1 | ||||
-rw-r--r-- | redis/asyncio/cluster.py | 31 | ||||
-rw-r--r-- | redis/cluster.py | 22 | ||||
-rw-r--r-- | tests/test_asyncio/test_cluster.py | 110 | ||||
-rw-r--r-- | tests/test_cluster.py | 129 |
5 files changed, 291 insertions, 2 deletions
@@ -1,3 +1,4 @@ + * Add `address_remap` parameter to `RedisCluster` * Fix incorrect usage of once flag in async Sentinel * asyncio: Fix memory leak caused by hiredis (#2693) * Allow data to drain from async PythonParser when reading during a disconnect() diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index a4a9561..eb5f4db 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -5,12 +5,14 @@ import socket import warnings from typing import ( Any, + Callable, Deque, Dict, Generator, List, Mapping, Optional, + Tuple, Type, TypeVar, Union, @@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand maximum number of connections are already created, a :class:`~.MaxConnectionsError` is raised. This error may be retried as defined by :attr:`connection_error_retry_attempts` + :param address_remap: + | An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. | Rest of the arguments will be passed to the :class:`~redis.asyncio.connection.Connection` instances when created @@ -250,6 +258,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -337,7 +346,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs) + self.nodes_manager = NodesManager( + startup_nodes, + require_full_coverage, + kwargs, + address_remap=address_remap, + ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas self.reinitialize_steps = reinitialize_steps @@ -1059,6 +1073,7 @@ class NodesManager: "require_full_coverage", "slots_cache", "startup_nodes", + "address_remap", ) def __init__( @@ -1066,10 +1081,12 @@ class NodesManager: startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage self.connection_kwargs = connection_kwargs + self.address_remap = address_remap self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1228,6 +1245,7 @@ class NodesManager: if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: @@ -1246,6 +1264,7 @@ class NodesManager: for replica_node in replica_nodes: host = replica_node[0] port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = tmp_nodes_cache.get( get_node_name(host, port) @@ -1319,6 +1338,16 @@ class NodesManager: ) ) + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ diff --git a/redis/cluster.py b/redis/cluster.py index 5e6e7da..3ecc2da 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -466,6 +466,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): read_from_replicas: bool = False, dynamic_startup_nodes: bool = True, url: Optional[str] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): """ @@ -514,6 +515,12 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): reinitialize_steps to 1. To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. + :param address_remap: + An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. :**kwargs: Extra arguments that will be sent into Redis instance when created @@ -594,6 +601,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): from_url=from_url, require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, + address_remap=address_remap, **kwargs, ) @@ -1269,6 +1277,7 @@ class NodesManager: lock=None, dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): self.nodes_cache = {} @@ -1280,6 +1289,7 @@ class NodesManager: self._require_full_coverage = require_full_coverage self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class + self.address_remap = address_remap self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1502,6 +1512,7 @@ class NodesManager: if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = self._get_or_create_cluster_node( host, port, PRIMARY, tmp_nodes_cache @@ -1518,6 +1529,7 @@ class NodesManager: for replica_node in replica_nodes: host = str_if_bytes(replica_node[0]) port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = self._get_or_create_cluster_node( host, port, REPLICA, tmp_nodes_cache @@ -1591,6 +1603,16 @@ class NodesManager: # The read_load_balancer is None, do nothing pass + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPubSub(PubSub): """ diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 13e5e26..6d0aba7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -11,7 +11,7 @@ import pytest_asyncio from _pytest.fixtures import FixtureRequest from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster -from redis.asyncio.connection import Connection, SSLConnection +from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff @@ -49,6 +49,71 @@ default_cluster_slots = [ ] +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.send_event = asyncio.Event() + self.server = None + self.task = None + self.n_connections = 0 + + async def start(self): + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + try: + self.n_connections += 1 + pipe1 = asyncio.create_task(self.pipe(reader, redis_writer)) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer)) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() + + async def aclose(self): + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + await self.server.wait_closed() + + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + while True: + data = await reader.read(1000) + if not data: + break + writer.write(data) + await writer.drain() + + +@pytest.fixture +def redis_addr(request): + redis_url = request.config.getoption("--redis-url") + scheme, netloc = urlparse(redis_url)[:2] + assert scheme == "redis" + if ":" in netloc: + host, port = netloc.split(":") + return host, int(port) + else: + return netloc, 6379 + + @pytest_asyncio.fixture() async def slowlog(r: RedisCluster) -> None: """ @@ -809,6 +874,49 @@ class TestRedisClusterObj: # Rollback to the old default node r.replace_default_node(curr_default_node) + async def test_address_remap(self, create_redis, redis_addr): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + ports = [redis_addr[1] + i for i in range(n)] + + def address_remap(address): + # remap first three nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) + for port in ports + ] + await asyncio.gather(*[p.start() for p in proxies]) + try: + # create cluster: + r = await create_redis( + cls=RedisCluster, flushdb=False, address_remap=address_remap + ) + try: + assert await r.ping() is True + assert await r.set("byte_string", b"giraffe") + assert await r.get("byte_string") == b"giraffe" + finally: + await r.close() + finally: + await asyncio.gather(*[p.aclose() for p in proxies]) + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + class TestClusterRedisCommands: """ diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 58f9b77..1f037c9 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,9 +1,14 @@ import binascii import datetime +import select +import socket +import socketserver +import threading import warnings from queue import LifoQueue, Queue from time import sleep from unittest.mock import DEFAULT, Mock, call, patch +from urllib.parse import urlparse import pytest @@ -53,6 +58,85 @@ default_cluster_slots = [ ] +class ProxyRequestHandler(socketserver.BaseRequestHandler): + def recv(self, sock): + """A recv with a timeout""" + r = select.select([sock], [], [], 0.01) + if not r[0]: + return None + return sock.recv(1000) + + def handle(self): + self.server.proxy.n_connections += 1 + conn = socket.create_connection(self.server.proxy.redis_addr) + stop = False + + def from_server(): + # read from server and pass to client + while not stop: + data = self.recv(conn) + if data is None: + continue + if not data: + self.request.shutdown(socket.SHUT_WR) + return + self.request.sendall(data) + + thread = threading.Thread(target=from_server) + thread.start() + try: + while True: + # read from client and send to server + data = self.request.recv(1000) + if not data: + return + conn.sendall(data) + finally: + conn.shutdown(socket.SHUT_WR) + stop = True # for safety + thread.join() + conn.close() + + +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.server = socketserver.ThreadingTCPServer(self.addr, ProxyRequestHandler) + self.server.proxy = self + self.server.socket_reuse_address = True + self.thread = None + self.n_connections = 0 + + def start(self): + # test that we can connect to redis + s = socket.create_connection(self.redis_addr, timeout=2) + s.close() + # Start a thread with the server -- that thread will then start one + # more thread for each request + self.thread = threading.Thread(target=self.server.serve_forever) + # Exit the server thread when the main thread terminates + self.thread.daemon = True + self.thread.start() + + def close(self): + self.server.shutdown() + + +@pytest.fixture +def redis_addr(request): + redis_url = request.config.getoption("--redis-url") + scheme, netloc = urlparse(redis_url)[:2] + assert scheme == "redis" + if ":" in netloc: + host, port = netloc.split(":") + return host, int(port) + else: + return netloc, 6379 + + @pytest.fixture() def slowlog(request, r): """ @@ -823,6 +907,51 @@ class TestRedisClusterObj: assert "myself" not in nodes.get(curr_default_node.name).get("flags") assert r.get_default_node() != curr_default_node + def test_address_remap(self, request, redis_addr): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + ports = [redis_addr[1] + i for i in range(n)] + + def address_remap(address): + # remap first three nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) + for port in ports + ] + for p in proxies: + p.start() + try: + # create cluster: + r = _get_client( + RedisCluster, request, flushdb=False, address_remap=address_remap + ) + try: + assert r.ping() is True + assert r.set("byte_string", b"giraffe") + assert r.get("byte_string") == b"giraffe" + finally: + r.close() + finally: + for p in proxies: + p.close() + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + @pytest.mark.onlycluster class TestClusterRedisCommands: |