summary refs log tree commit diff
diff options
context:
space:
mode:
author Kristján Valur Jónsson <sweskman@gmail.com> 2023-05-02 15:06:29 +0000
committer GitHub <noreply@github.com> 2023-05-02 18:06:29 +0300
commit a7857e106bad02f4fc01c6ae69573d53d9018950 (patch)
tree e38146914e690d7af6555ccfa0313110a57097f5
parent ac15d529edf2832af4c95349f6c0e9af2418448d (diff)
download redis-py-a7857e106bad02f4fc01c6ae69573d53d9018950.tar.gz
add "address_remap" feature to RedisCluster (#2726)
* add cluster "host_port_remap" feature for asyncio.RedisCluster
* Add a unittest for asyncio.RedisCluster
* Add host_port_remap to _sync_ RedisCluster
* add synchronous tests
* rename arg to `address_remap` and take and return an address tuple.
* Add class documentation
* Add CHANGES
-rw-r--r-- CHANGES 1
-rw-r--r-- redis/asyncio/cluster.py 31
-rw-r--r-- redis/cluster.py 22
-rw-r--r-- tests/test_asyncio/test_cluster.py 110
-rw-r--r-- tests/test_cluster.py 129
5 files changed, 291 insertions, 2 deletions
diff --git a/CHANGES b/CHANGES
index 8f20172..3865ed1 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,4 @@
+ * Add `address_remap` parameter to `RedisCluster`
* Fix incorrect usage of once flag in async Sentinel
* asyncio: Fix memory leak caused by hiredis (#2693)
* Allow data to drain from async PythonParser when reading during a disconnect()
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index a4a9561..eb5f4db 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -5,12 +5,14 @@ import socket
import warnings
from typing import (
Any,
+ Callable,
Deque,
Dict,
Generator,
List,
Mapping,
Optional,
+ Tuple,
Type,
TypeVar,
Union,
@@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
maximum number of connections are already created, a
:class:`~.MaxConnectionsError` is raised. This error may be retried as defined
by :attr:`connection_error_retry_attempts`
+ :param address_remap:
+ | An optional callable which, when provided with an internal network
+ address of a node, e.g. a `(host, port)` tuple, will return the address
+ where the node is reachable. This can be used to map the addresses at
+ which the nodes _think_ they are, to addresses at which a client may
+ reach them, such as when they sit behind a proxy.
| Rest of the arguments will be passed to the
:class:`~redis.asyncio.connection.Connection` instances when created
@@ -250,6 +258,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
+ address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
if db:
raise RedisClusterException(
@@ -337,7 +346,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
- self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
+ self.nodes_manager = NodesManager(
+ startup_nodes,
+ require_full_coverage,
+ kwargs,
+ address_remap=address_remap,
+ )
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
@@ -1059,6 +1073,7 @@ class NodesManager:
"require_full_coverage",
"slots_cache",
"startup_nodes",
+ "address_remap",
)
def __init__(
@@ -1066,10 +1081,12 @@ class NodesManager:
startup_nodes: List["ClusterNode"],
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
+ address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
self.connection_kwargs = connection_kwargs
+ self.address_remap = address_remap
self.default_node: "ClusterNode" = None
self.nodes_cache: Dict[str, "ClusterNode"] = {}
@@ -1228,6 +1245,7 @@ class NodesManager:
if host == "":
host = startup_node.host
port = int(primary_node[1])
+ host, port = self.remap_host_port(host, port)
target_node = tmp_nodes_cache.get(get_node_name(host, port))
if not target_node:
@@ -1246,6 +1264,7 @@ class NodesManager:
for replica_node in replica_nodes:
host = replica_node[0]
port = replica_node[1]
+ host, port = self.remap_host_port(host, port)
target_replica_node = tmp_nodes_cache.get(
get_node_name(host, port)
@@ -1319,6 +1338,16 @@ class NodesManager:
)
)
+ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
+ """
+ Remap the host and port reported by the cluster to the address at
+ which the client can reach that node. Useful if the client is not
+ connecting directly to the cluster (e.g. through a proxy).
+ """
+ if self.address_remap:
+ return self.address_remap((host, port))
+ return host, port
+
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
"""
diff --git a/redis/cluster.py b/redis/cluster.py
index 5e6e7da..3ecc2da 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -466,6 +466,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
read_from_replicas: bool = False,
dynamic_startup_nodes: bool = True,
url: Optional[str] = None,
+ address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
"""
@@ -514,6 +515,12 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
reinitialize_steps to 1.
To avoid reinitializing the cluster on moved errors, set
reinitialize_steps to 0.
+ :param address_remap:
+ An optional callable which, when provided with an internal network
+ address of a node, e.g. a `(host, port)` tuple, will return the address
+ where the node is reachable. This can be used to map the addresses at
+ which the nodes _think_ they are, to addresses at which a client may
+ reach them, such as when they sit behind a proxy.
:**kwargs:
Extra arguments that will be sent into Redis instance when created
@@ -594,6 +601,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
from_url=from_url,
require_full_coverage=require_full_coverage,
dynamic_startup_nodes=dynamic_startup_nodes,
+ address_remap=address_remap,
**kwargs,
)
@@ -1269,6 +1277,7 @@ class NodesManager:
lock=None,
dynamic_startup_nodes=True,
connection_pool_class=ConnectionPool,
+ address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
self.nodes_cache = {}
@@ -1280,6 +1289,7 @@ class NodesManager:
self._require_full_coverage = require_full_coverage
self._dynamic_startup_nodes = dynamic_startup_nodes
self.connection_pool_class = connection_pool_class
+ self.address_remap = address_remap
self._moved_exception = None
self.connection_kwargs = kwargs
self.read_load_balancer = LoadBalancer()
@@ -1502,6 +1512,7 @@ class NodesManager:
if host == "":
host = startup_node.host
port = int(primary_node[1])
+ host, port = self.remap_host_port(host, port)
target_node = self._get_or_create_cluster_node(
host, port, PRIMARY, tmp_nodes_cache
@@ -1518,6 +1529,7 @@ class NodesManager:
for replica_node in replica_nodes:
host = str_if_bytes(replica_node[0])
port = replica_node[1]
+ host, port = self.remap_host_port(host, port)
target_replica_node = self._get_or_create_cluster_node(
host, port, REPLICA, tmp_nodes_cache
@@ -1591,6 +1603,16 @@ class NodesManager:
# The read_load_balancer is None, do nothing
pass
+ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
+ """
+ Remap the host and port reported by the cluster to the address at
+ which the client can reach that node. Useful if the client is not
+ connecting directly to the cluster (e.g. through a proxy).
+ """
+ if self.address_remap:
+ return self.address_remap((host, port))
+ return host, port
+
class ClusterPubSub(PubSub):
"""
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 13e5e26..6d0aba7 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -11,7 +11,7 @@ import pytest_asyncio
from _pytest.fixtures import FixtureRequest
from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
-from redis.asyncio.connection import Connection, SSLConnection
+from redis.asyncio.connection import Connection, SSLConnection, async_timeout
from redis.asyncio.parser import CommandsParser
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
@@ -49,6 +49,71 @@ default_cluster_slots = [
]
+class NodeProxy:
+ """A class to proxy a node connection to a different port"""
+
+ def __init__(self, addr, redis_addr):
+ self.addr = addr
+ self.redis_addr = redis_addr
+ self.send_event = asyncio.Event()
+ self.server = None
+ self.task = None
+ self.n_connections = 0
+
+ async def start(self):
+ # test that we can connect to redis
+ async with async_timeout(2):
+ _, redis_writer = await asyncio.open_connection(*self.redis_addr)
+ redis_writer.close()
+ self.server = await asyncio.start_server(
+ self.handle, *self.addr, reuse_address=True
+ )
+ self.task = asyncio.create_task(self.server.serve_forever())
+
+ async def handle(self, reader, writer):
+ # establish connection to redis
+ redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
+ try:
+ self.n_connections += 1
+ pipe1 = asyncio.create_task(self.pipe(reader, redis_writer))
+ pipe2 = asyncio.create_task(self.pipe(redis_reader, writer))
+ await asyncio.gather(pipe1, pipe2)
+ finally:
+ redis_writer.close()
+
+ async def aclose(self):
+ self.task.cancel()
+ try:
+ await self.task
+ except asyncio.CancelledError:
+ pass
+ await self.server.wait_closed()
+
+ async def pipe(
+ self,
+ reader: asyncio.StreamReader,
+ writer: asyncio.StreamWriter,
+ ):
+ while True:
+ data = await reader.read(1000)
+ if not data:
+ break
+ writer.write(data)
+ await writer.drain()
+
+
+@pytest.fixture
+def redis_addr(request):
+ redis_url = request.config.getoption("--redis-url")
+ scheme, netloc = urlparse(redis_url)[:2]
+ assert scheme == "redis"
+ if ":" in netloc:
+ host, port = netloc.split(":")
+ return host, int(port)
+ else:
+ return netloc, 6379
+
+
@pytest_asyncio.fixture()
async def slowlog(r: RedisCluster) -> None:
"""
@@ -809,6 +874,49 @@ class TestRedisClusterObj:
# Rollback to the old default node
r.replace_default_node(curr_default_node)
+ async def test_address_remap(self, create_redis, redis_addr):
+ """Test that we can create a rediscluster object with
+ a host-port remapper and map connections through proxy objects
+ """
+
+ # we remap the first n nodes
+ offset = 1000
+ n = 6
+ ports = [redis_addr[1] + i for i in range(n)]
+
+ def address_remap(address):
+ # remap the first n nodes to our local proxy
+ # old = host, port
+ host, port = address
+ if int(port) in ports:
+ host, port = "127.0.0.1", int(port) + offset
+ # print(f"{old} {host, port}")
+ return host, port
+
+ # create the proxies
+ proxies = [
+ NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port))
+ for port in ports
+ ]
+ await asyncio.gather(*[p.start() for p in proxies])
+ try:
+ # create cluster:
+ r = await create_redis(
+ cls=RedisCluster, flushdb=False, address_remap=address_remap
+ )
+ try:
+ assert await r.ping() is True
+ assert await r.set("byte_string", b"giraffe")
+ assert await r.get("byte_string") == b"giraffe"
+ finally:
+ await r.close()
+ finally:
+ await asyncio.gather(*[p.aclose() for p in proxies])
+
+ # verify that the proxies were indeed used
+ n_used = sum((1 if p.n_connections else 0) for p in proxies)
+ assert n_used > 1
+
class TestClusterRedisCommands:
"""
diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index 58f9b77..1f037c9 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -1,9 +1,14 @@
import binascii
import datetime
+import select
+import socket
+import socketserver
+import threading
import warnings
from queue import LifoQueue, Queue
from time import sleep
from unittest.mock import DEFAULT, Mock, call, patch
+from urllib.parse import urlparse
import pytest
@@ -53,6 +58,85 @@ default_cluster_slots = [
]
+class ProxyRequestHandler(socketserver.BaseRequestHandler):
+ def recv(self, sock):
+ """A recv with a timeout"""
+ r = select.select([sock], [], [], 0.01)
+ if not r[0]:
+ return None
+ return sock.recv(1000)
+
+ def handle(self):
+ self.server.proxy.n_connections += 1
+ conn = socket.create_connection(self.server.proxy.redis_addr)
+ stop = False
+
+ def from_server():
+ # read from server and pass to client
+ while not stop:
+ data = self.recv(conn)
+ if data is None:
+ continue
+ if not data:
+ self.request.shutdown(socket.SHUT_WR)
+ return
+ self.request.sendall(data)
+
+ thread = threading.Thread(target=from_server)
+ thread.start()
+ try:
+ while True:
+ # read from client and send to server
+ data = self.request.recv(1000)
+ if not data:
+ return
+ conn.sendall(data)
+ finally:
+ conn.shutdown(socket.SHUT_WR)
+ stop = True # for safety
+ thread.join()
+ conn.close()
+
+
+class NodeProxy:
+ """A class to proxy a node connection to a different port"""
+
+ def __init__(self, addr, redis_addr):
+ self.addr = addr
+ self.redis_addr = redis_addr
+ self.server = socketserver.ThreadingTCPServer(self.addr, ProxyRequestHandler)
+ self.server.proxy = self
+ self.server.socket_reuse_address = True
+ self.thread = None
+ self.n_connections = 0
+
+ def start(self):
+ # test that we can connect to redis
+ s = socket.create_connection(self.redis_addr, timeout=2)
+ s.close()
+ # Start a thread with the server -- that thread will then start one
+ # more thread for each request
+ self.thread = threading.Thread(target=self.server.serve_forever)
+ # Exit the server thread when the main thread terminates
+ self.thread.daemon = True
+ self.thread.start()
+
+ def close(self):
+ self.server.shutdown()
+
+
+@pytest.fixture
+def redis_addr(request):
+ redis_url = request.config.getoption("--redis-url")
+ scheme, netloc = urlparse(redis_url)[:2]
+ assert scheme == "redis"
+ if ":" in netloc:
+ host, port = netloc.split(":")
+ return host, int(port)
+ else:
+ return netloc, 6379
+
+
@pytest.fixture()
def slowlog(request, r):
"""
@@ -823,6 +907,51 @@ class TestRedisClusterObj:
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
assert r.get_default_node() != curr_default_node
+ def test_address_remap(self, request, redis_addr):
+ """Test that we can create a rediscluster object with
+ a host-port remapper and map connections through proxy objects
+ """
+
+ # we remap the first n nodes
+ offset = 1000
+ n = 6
+ ports = [redis_addr[1] + i for i in range(n)]
+
+ def address_remap(address):
+ # remap the first n nodes to our local proxy
+ # old = host, port
+ host, port = address
+ if int(port) in ports:
+ host, port = "127.0.0.1", int(port) + offset
+ # print(f"{old} {host, port}")
+ return host, port
+
+ # create the proxies
+ proxies = [
+ NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port))
+ for port in ports
+ ]
+ for p in proxies:
+ p.start()
+ try:
+ # create cluster:
+ r = _get_client(
+ RedisCluster, request, flushdb=False, address_remap=address_remap
+ )
+ try:
+ assert r.ping() is True
+ assert r.set("byte_string", b"giraffe")
+ assert r.get("byte_string") == b"giraffe"
+ finally:
+ r.close()
+ finally:
+ for p in proxies:
+ p.close()
+
+ # verify that the proxies were indeed used
+ n_used = sum((1 if p.n_connections else 0) for p in proxies)
+ assert n_used > 1
+
@pytest.mark.onlycluster
class TestClusterRedisCommands: