summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordvora-h <67596500+dvora-h@users.noreply.github.com>2023-04-24 15:49:27 +0300
committerGitHub <noreply@github.com>2023-04-24 15:49:27 +0300
commita96a38a0bb5aa05f22ad6fa3a3f5235e70b46ee3 (patch)
tree3f4e98de9020c7c3ef11b6a987133ba96ce042c5
parent0db4ebad9c47e2bcf509ae5320c94944ceb48124 (diff)
downloadredis-py-a96a38a0bb5aa05f22ad6fa3a3f5235e70b46ee3.tar.gz
Add support for PubSub with RESP3 parser (#2721)
* add resp3 pubsub * linters * _set_info_logger func * async pubsun * docstring
-rw-r--r--redis/asyncio/client.py20
-rw-r--r--redis/asyncio/connection.py16
-rwxr-xr-xredis/client.py16
-rw-r--r--redis/connection.py12
-rw-r--r--redis/parsers/resp3.py81
-rw-r--r--redis/utils.py14
-rw-r--r--tests/test_asyncio/test_pubsub.py31
-rw-r--r--tests/test_pubsub.py37
8 files changed, 197 insertions, 30 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index ffd68c1..5ef1f32 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -57,7 +57,7 @@ from redis.exceptions import (
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
-from redis.utils import safe_str, str_if_bytes
+from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -658,6 +658,7 @@ class PubSub:
shard_hint: Optional[str] = None,
ignore_subscribe_messages: bool = False,
encoder=None,
+ push_handler_func: Optional[Callable] = None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
@@ -666,6 +667,7 @@ class PubSub:
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
+ self.push_handler_func = push_handler_func
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
if self.encoder.decode_responses:
@@ -678,6 +680,8 @@ class PubSub:
b"pong",
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
+ if self.push_handler_func is None:
+ _set_info_logger()
self.channels = {}
self.pending_unsubscribe_channels = set()
self.patterns = {}
@@ -757,6 +761,8 @@ class PubSub:
self.connection.register_connect_callback(self.on_connect)
else:
await self.connection.connect()
+ if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
+ self.connection._parser.set_push_handler(self.push_handler_func)
async def _disconnect_raise_connect(self, conn, error):
"""
@@ -797,7 +803,9 @@ class PubSub:
await conn.connect()
read_timeout = None if block else timeout
- response = await self._execute(conn, conn.read_response, timeout=read_timeout)
+ response = await self._execute(
+ conn, conn.read_response, timeout=read_timeout, push_request=True
+ )
if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
@@ -927,8 +935,8 @@ class PubSub:
"""
Ping the Redis server
"""
- message = "" if message is None else message
- return self.execute_command("PING", message)
+ args = ["PING", message] if message is not None else ["PING"]
+ return self.execute_command(*args)
async def handle_message(self, response, ignore_subscribe_messages=False):
"""
@@ -936,6 +944,10 @@ class PubSub:
with a message handler, the handler is invoked instead of a parsed
message being returned.
"""
+ if response is None:
+ return None
+ if isinstance(response, bytes):
+ response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
message_type = str_if_bytes(response[0])
if message_type == "pmessage":
message = {
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index d9c9583..bc872ff 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -485,15 +485,29 @@ class Connection:
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
+ push_request: Optional[bool] = False,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
try:
- if read_timeout is not None:
+ if (
+ read_timeout is not None
+ and self.protocol == "3"
+ and not HIREDIS_AVAILABLE
+ ):
+ async with async_timeout(read_timeout):
+ response = await self._parser.read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
+ elif read_timeout is not None:
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
+ elif self.protocol == "3" and not HIREDIS_AVAILABLE:
+ response = await self._parser.read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
else:
response = await self._parser.read_response(
disable_decoding=disable_decoding
diff --git a/redis/client.py b/redis/client.py
index 15dddc9..71048f5 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -27,7 +27,7 @@ from redis.exceptions import (
)
from redis.lock import Lock
from redis.retry import Retry
-from redis.utils import safe_str, str_if_bytes
+from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes
SYM_EMPTY = b""
EMPTY_RESPONSE = "EMPTY_RESPONSE"
@@ -1429,6 +1429,7 @@ class PubSub:
shard_hint=None,
ignore_subscribe_messages=False,
encoder=None,
+ push_handler_func=None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
@@ -1438,6 +1439,7 @@ class PubSub:
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
+ self.push_handler_func = push_handler_func
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
@@ -1445,6 +1447,8 @@ class PubSub:
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
else:
self.health_check_response = [b"pong", self.health_check_response_b]
+ if self.push_handler_func is None:
+ _set_info_logger()
self.reset()
def __enter__(self):
@@ -1515,6 +1519,8 @@ class PubSub:
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
+ if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
+ self.connection._parser.set_push_handler(self.push_handler_func)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
@@ -1580,7 +1586,7 @@ class PubSub:
return None
else:
conn.connect()
- return conn.read_response()
+ return conn.read_response(push_request=True)
response = self._execute(conn, try_read)
@@ -1739,8 +1745,8 @@ class PubSub:
"""
Ping the Redis server
"""
- message = "" if message is None else message
- return self.execute_command("PING", message)
+ args = ["PING", message] if message is not None else ["PING"]
+ return self.execute_command(*args)
def handle_message(self, response, ignore_subscribe_messages=False):
"""
@@ -1750,6 +1756,8 @@ class PubSub:
"""
if response is None:
return None
+ if isinstance(response, bytes):
+ response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
message_type = str_if_bytes(response[0])
if message_type == "pmessage":
message = {
diff --git a/redis/connection.py b/redis/connection.py
index 85509f7..19c80e0 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -406,13 +406,18 @@ class AbstractConnection:
self.disconnect()
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
- def read_response(self, disable_decoding=False):
+ def read_response(self, disable_decoding=False, push_request=False):
"""Read the response from a previously sent command"""
host_error = self._host_error()
try:
- response = self._parser.read_response(disable_decoding=disable_decoding)
+ if self.protocol == "3" and not HIREDIS_AVAILABLE:
+ response = self._parser.read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
+ else:
+ response = self._parser.read_response(disable_decoding=disable_decoding)
except socket.timeout:
self.disconnect()
raise TimeoutError(f"Timeout reading from {host_error}")
@@ -705,8 +710,9 @@ class SSLConnection(Connection):
class UnixDomainSocketConnection(AbstractConnection):
"Manages UDS communication to and from a Redis server"
- def __init__(self, path="", **kwargs):
+ def __init__(self, path="", socket_timeout=None, **kwargs):
self.path = path
+ self.socket_timeout = socket_timeout
super().__init__(**kwargs)
def repr_pieces(self):
diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py
index 2753d39..93fb6ff 100644
--- a/redis/parsers/resp3.py
+++ b/redis/parsers/resp3.py
@@ -1,3 +1,4 @@
+from logging import getLogger
from typing import Any, Union
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
@@ -9,10 +10,21 @@ from .socket import SERVER_CLOSED_CONNECTION_ERROR
class _RESP3Parser(_RESPBase):
"""RESP3 protocol implementation"""
- def read_response(self, disable_decoding=False):
+ def __init__(self, socket_read_size):
+ super().__init__(socket_read_size)
+ self.push_handler_func = self.handle_push_response
+
+ def handle_push_response(self, response):
+ logger = getLogger("push_response")
+ logger.info("Push response: " + str(response))
+ return response
+
+ def read_response(self, disable_decoding=False, push_request=False):
pos = self._buffer.get_pos()
try:
- result = self._read_response(disable_decoding=disable_decoding)
+ result = self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
except BaseException:
self._buffer.rewind(pos)
raise
@@ -20,7 +32,7 @@ class _RESP3Parser(_RESPBase):
self._buffer.purge()
return result
- def _read_response(self, disable_decoding=False):
+ def _read_response(self, disable_decoding=False, push_request=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -77,9 +89,26 @@ class _RESP3Parser(_RESPBase):
response = {
self._read_response(
disable_decoding=disable_decoding
- ): self._read_response(disable_decoding=disable_decoding)
+ ): self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
for _ in range(int(response))
}
+ # push response
+ elif byte == b">":
+ response = [
+ self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
+ for _ in range(int(response))
+ ]
+ res = self.push_handler_func(response)
+ if not push_request:
+ return self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
+ else:
+ return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
@@ -87,21 +116,37 @@ class _RESP3Parser(_RESPBase):
response = self.encoder.decode(response)
return response
+ def set_push_handler(self, push_handler_func):
+ self.push_handler_func = push_handler_func
+
class _AsyncRESP3Parser(_AsyncRESPBase):
- async def read_response(self, disable_decoding: bool = False):
+ def __init__(self, socket_read_size):
+ super().__init__(socket_read_size)
+ self.push_handler_func = self.handle_push_response
+
+ def handle_push_response(self, response):
+ logger = getLogger("push_response")
+ logger.info("Push response: " + str(response))
+ return response
+
+ async def read_response(
+ self, disable_decoding: bool = False, push_request: bool = False
+ ):
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
- response = await self._read_response(disable_decoding=disable_decoding)
+ response = await self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response
async def _read_response(
- self, disable_decoding: bool = False
+ self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -166,9 +211,31 @@ class _AsyncRESP3Parser(_AsyncRESPBase):
)
for _ in range(int(response))
}
+ # push response
+ elif byte == b">":
+ response = [
+ (
+ await self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
+ )
+ for _ in range(int(response))
+ ]
+ res = self.push_handler_func(response)
+ if not push_request:
+ return await (
+ self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
+ )
+ else:
+ return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response
+
+ def set_push_handler(self, push_handler_func):
+ self.push_handler_func = push_handler_func
diff --git a/redis/utils.py b/redis/utils.py
index a6e6200..148d152 100644
--- a/redis/utils.py
+++ b/redis/utils.py
@@ -1,3 +1,4 @@
+import logging
from contextlib import contextmanager
from functools import wraps
from typing import Any, Dict, Mapping, Union
@@ -117,3 +118,16 @@ def deprecated_function(reason="", version="", name=None):
return wrapper
return decorator
+
+
+def _set_info_logger():
+ """
+ Set up a logger that log info logs to stdout.
+ (This is used by the default push response handler)
+ """
+ if "push_response" not in logging.root.manager.loggerDict.keys():
+ logger = logging.getLogger("push_response")
+ logger.setLevel(logging.INFO)
+ handler = logging.StreamHandler()
+ handler.setLevel(logging.INFO)
+ logger.addHandler(handler)
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
index 0c0b7db..8cd5cf6 100644
--- a/tests/test_asyncio/test_pubsub.py
+++ b/tests/test_asyncio/test_pubsub.py
@@ -16,9 +16,11 @@ import pytest_asyncio
import redis.asyncio as redis
from redis.exceptions import ConnectionError
from redis.typing import EncodableT
+from redis.utils import HIREDIS_AVAILABLE
from tests.conftest import skip_if_server_version_lt
from .compat import create_task, mock
+from .conftest import get_protocol_version
def with_timeout(t):
@@ -420,6 +422,23 @@ class TestPubSubMessages:
assert expect in info.exconly()
+class TestPubSubRESP3Handler:
+ def my_handler(self, message):
+ self.message = ["my handler", message]
+
+ @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+ async def test_push_handler(self, r):
+ if get_protocol_version(r) in [2, "2", None]:
+ return
+ p = r.pubsub(push_handler_func=self.my_handler)
+ await p.subscribe("foo")
+ assert await wait_for_message(p) is None
+ assert self.message == ["my handler", [b"subscribe", b"foo", 1]]
+ assert await r.publish("foo", "test message") == 1
+ assert await wait_for_message(p) is None
+ assert self.message == ["my handler", [b"message", b"foo", b"test message"]]
+
+
@pytest.mark.onlynoncluster
class TestPubSubAutoDecoding:
"""These tests only validate that we get unicode values back"""
@@ -995,13 +1014,15 @@ class TestBaseException:
assert msg is not None
# timeout waiting for another message which never arrives
assert pubsub.connection.is_connected
- with patch("redis.parsers._AsyncRESP2Parser.read_response") as mock1:
+ with patch("redis.parsers._AsyncRESP2Parser.read_response") as mock1, patch(
+ "redis.parsers._AsyncHiredisParser.read_response"
+ ) as mock2, patch("redis.parsers._AsyncRESP3Parser.read_response") as mock3:
mock1.side_effect = BaseException("boom")
- with patch("redis.parsers._AsyncHiredisParser.read_response") as mock2:
- mock2.side_effect = BaseException("boom")
+ mock2.side_effect = BaseException("boom")
+ mock3.side_effect = BaseException("boom")
- with pytest.raises(BaseException):
- await get_msg()
+ with pytest.raises(BaseException):
+ await get_msg()
# the timeout on the read should not cause disconnect
assert pubsub.connection.is_connected
diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py
index 48c0f3a..e1e4311 100644
--- a/tests/test_pubsub.py
+++ b/tests/test_pubsub.py
@@ -10,8 +10,14 @@ import pytest
import redis
from redis.exceptions import ConnectionError
+from redis.utils import HIREDIS_AVAILABLE
-from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt
+from .conftest import (
+ _get_client,
+ is_resp2_connection,
+ skip_if_redis_enterprise,
+ skip_if_server_version_lt,
+)
def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False):
@@ -352,6 +358,23 @@ class TestPubSubMessages:
)
+class TestPubSubRESP3Handler:
+ def my_handler(self, message):
+ self.message = ["my handler", message]
+
+ @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+ def test_push_handler(self, r):
+ if is_resp2_connection(r):
+ return
+ p = r.pubsub(push_handler_func=self.my_handler)
+ p.subscribe("foo")
+ assert wait_for_message(p) is None
+ assert self.message == ["my handler", [b"subscribe", b"foo", 1]]
+ assert r.publish("foo", "test message") == 1
+ assert wait_for_message(p) is None
+ assert self.message == ["my handler", [b"message", b"foo", b"test message"]]
+
+
class TestPubSubAutoDecoding:
"These tests only validate that we get unicode values back"
@@ -767,13 +790,15 @@ class TestBaseException:
assert msg is not None
# timeout waiting for another message which never arrives
assert is_connected()
- with patch("redis.parsers._RESP2Parser.read_response") as mock1:
+ with patch("redis.parsers._RESP2Parser.read_response") as mock1, patch(
+ "redis.parsers._HiredisParser.read_response"
+ ) as mock2, patch("redis.parsers._RESP3Parser.read_response") as mock3:
mock1.side_effect = BaseException("boom")
- with patch("redis.parsers._HiredisParser.read_response") as mock2:
- mock2.side_effect = BaseException("boom")
+ mock2.side_effect = BaseException("boom")
+ mock3.side_effect = BaseException("boom")
- with pytest.raises(BaseException):
- get_msg()
+ with pytest.raises(BaseException):
+ get_msg()
# the timeout on the read should not cause disconnect
assert is_connected()