summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES1
-rw-r--r--redis/asyncio/connection.py69
-rwxr-xr-xredis/connection.py58
-rw-r--r--tests/mocks.py41
-rw-r--r--tests/test_asyncio/mocks.py51
-rw-r--r--tests/test_asyncio/test_connection.py48
-rw-r--r--tests/test_connection.py43
7 files changed, 269 insertions, 42 deletions
diff --git a/CHANGES b/CHANGES
index 228910f..fca8d31 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,4 @@
+ * Make PythonParser resumable in case of error (#2510)
* Add `timeout=None` in `SentinelConnectionManager.read_response`
* Documentation fix: password protected socket connection (#2374)
* Allow `timeout=None` in `PubSub.get_message()` to wait forever
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index d031f41..2c75d4f 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -208,11 +208,18 @@ class BaseParser:
class PythonParser(BaseParser):
"""Plain Python parsing class"""
- __slots__ = BaseParser.__slots__ + ("encoder",)
+ __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
def __init__(self, socket_read_size: int):
super().__init__(socket_read_size)
self.encoder: Optional[Encoder] = None
+ self._buffer = b""
+ self._chunks = []
+ self._pos = 0
+
+ def _clear(self):
+ self._buffer = b""
+ self._chunks.clear()
def on_connect(self, connection: "Connection"):
"""Called when the stream connects"""
@@ -227,8 +234,11 @@ class PythonParser(BaseParser):
if self._stream is not None:
self._stream = None
self.encoder = None
+ self._clear()
async def can_read_destructive(self) -> bool:
+ if self._buffer:
+ return True
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
@@ -237,14 +247,23 @@ class PythonParser(BaseParser):
except asyncio.TimeoutError:
return False
- async def read_response(
+ async def read_response(self, disable_decoding: bool = False):
+ if self._chunks:
+ # augment parsing buffer with previously read data
+ self._buffer += b"".join(self._chunks)
+ self._chunks.clear()
+ self._pos = 0
+ response = await self._read_response(disable_decoding=disable_decoding)
+ # Successfully parsing a response allows us to clear our parsing buffer
+ self._clear()
+ return response
+
+ async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
- if not raw:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
response: Any
byte, response = raw[:1], raw[1:]
@@ -258,6 +277,7 @@ class PythonParser(BaseParser):
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
+ self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
@@ -282,7 +302,7 @@ class PythonParser(BaseParser):
if length == -1:
return None
response = [
- (await self.read_response(disable_decoding)) for _ in range(length)
+ (await self._read_response(disable_decoding)) for _ in range(length)
]
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
@@ -293,25 +313,38 @@ class PythonParser(BaseParser):
Read `length` bytes of data. These are assumed to be followed
by a '\r\n' terminator which is subsequently discarded.
"""
- if self._stream is None:
- raise RedisError("Buffer is closed.")
- try:
- data = await self._stream.readexactly(length + 2)
- except asyncio.IncompleteReadError as error:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
- return data[:-2]
+ want = length + 2
+ end = self._pos + want
+ if len(self._buffer) >= end:
+ result = self._buffer[self._pos : end - 2]
+ else:
+ tail = self._buffer[self._pos :]
+ try:
+ data = await self._stream.readexactly(want - len(tail))
+ except asyncio.IncompleteReadError as error:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
+ result = (tail + data)[:-2]
+ self._chunks.append(data)
+ self._pos += want
+ return result
async def _readline(self) -> bytes:
"""
read an unknown number of bytes up to the next '\r\n'
line separator, which is discarded.
"""
- if self._stream is None:
- raise RedisError("Buffer is closed.")
- data = await self._stream.readline()
- if not data.endswith(b"\r\n"):
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- return data[:-2]
+ found = self._buffer.find(b"\r\n", self._pos)
+ if found >= 0:
+ result = self._buffer[self._pos : found]
+ else:
+ tail = self._buffer[self._pos :]
+ data = await self._stream.readline()
+ if not data.endswith(b"\r\n"):
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ result = (tail + data)[:-2]
+ self._chunks.append(data)
+ self._pos += len(result) + 2
+ return result
class HiredisParser(BaseParser):
diff --git a/redis/connection.py b/redis/connection.py
index b810fc5..126ea5d 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -232,12 +232,6 @@ class SocketBuffer:
self._buffer.seek(self.bytes_read)
data = self._buffer.read(length)
self.bytes_read += len(data)
-
- # purge the buffer when we've consumed it all so it doesn't
- # grow forever
- if self.bytes_read == self.bytes_written:
- self.purge()
-
return data[:-2]
def readline(self):
@@ -251,23 +245,44 @@ class SocketBuffer:
data = buf.readline()
self.bytes_read += len(data)
+ return data[:-2]
- # purge the buffer when we've consumed it all so it doesn't
- # grow forever
- if self.bytes_read == self.bytes_written:
- self.purge()
+ def get_pos(self):
+ """
+ Get current read position
+ """
+ return self.bytes_read
- return data[:-2]
+ def rewind(self, pos):
+ """
+ Rewind the buffer to a specific position, to re-start reading
+ """
+ self.bytes_read = pos
def purge(self):
- self._buffer.seek(0)
- self._buffer.truncate()
- self.bytes_written = 0
+ """
+ After a successful read, purge the read part of buffer
+ """
+ unread = self.bytes_written - self.bytes_read
+
+ # Only if we have read all of the buffer do we truncate, to
+ # reduce the amount of memory thrashing. This heuristic
+ # can be changed or removed later.
+ if unread > 0:
+ return
+
+ if unread > 0:
+ # move unread data to the front
+ view = self._buffer.getbuffer()
+ view[:unread] = view[-unread:]
+ self._buffer.truncate(unread)
+ self.bytes_written = unread
self.bytes_read = 0
+ self._buffer.seek(0)
def close(self):
try:
- self.purge()
+ self.bytes_written = self.bytes_read = 0
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
@@ -315,6 +330,17 @@ class PythonParser(BaseParser):
return self._buffer and self._buffer.can_read(timeout)
def read_response(self, disable_decoding=False):
+ pos = self._buffer.get_pos()
+ try:
+ result = self._read_response(disable_decoding=disable_decoding)
+ except BaseException:
+ self._buffer.rewind(pos)
+ raise
+ else:
+ self._buffer.purge()
+ return result
+
+ def _read_response(self, disable_decoding=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -355,7 +381,7 @@ class PythonParser(BaseParser):
if length == -1:
return None
response = [
- self.read_response(disable_decoding=disable_decoding)
+ self._read_response(disable_decoding=disable_decoding)
for i in range(length)
]
if isinstance(response, bytes) and disable_decoding is False:
diff --git a/tests/mocks.py b/tests/mocks.py
new file mode 100644
index 0000000..d7d450e
--- /dev/null
+++ b/tests/mocks.py
@@ -0,0 +1,41 @@
+# Various mocks for testing
+
+
+class MockSocket:
+ """
+ A class simulating an readable socket, optionally raising a
+ special exception every other read.
+ """
+
+ class TestError(BaseException):
+ pass
+
+ def __init__(self, data, interrupt_every=0):
+ self.data = data
+ self.counter = 0
+ self.pos = 0
+ self.interrupt_every = interrupt_every
+
+ def tick(self):
+ self.counter += 1
+ if not self.interrupt_every:
+ return
+ if (self.counter % self.interrupt_every) == 0:
+ raise self.TestError()
+
+ def recv(self, bufsize):
+ self.tick()
+ bufsize = min(5, bufsize) # truncate the read size
+ result = self.data[self.pos : self.pos + bufsize]
+ self.pos += len(result)
+ return result
+
+ def recv_into(self, buffer, nbytes=0, flags=0):
+ self.tick()
+ if nbytes == 0:
+ nbytes = len(buffer)
+ nbytes = min(5, nbytes) # truncate the read size
+ result = self.data[self.pos : self.pos + nbytes]
+ self.pos += len(result)
+ buffer[: len(result)] = result
+ return len(result)
diff --git a/tests/test_asyncio/mocks.py b/tests/test_asyncio/mocks.py
new file mode 100644
index 0000000..89bd9c0
--- /dev/null
+++ b/tests/test_asyncio/mocks.py
@@ -0,0 +1,51 @@
+import asyncio
+
+# Helper Mocking classes for the tests.
+
+
+class MockStream:
+ """
+ A class simulating an asyncio input buffer, optionally raising a
+ special exception every other read.
+ """
+
+ class TestError(BaseException):
+ pass
+
+ def __init__(self, data, interrupt_every=0):
+ self.data = data
+ self.counter = 0
+ self.pos = 0
+ self.interrupt_every = interrupt_every
+
+ def tick(self):
+ self.counter += 1
+ if not self.interrupt_every:
+ return
+ if (self.counter % self.interrupt_every) == 0:
+ raise self.TestError()
+
+ async def read(self, want):
+ self.tick()
+ want = 5
+ result = self.data[self.pos : self.pos + want]
+ self.pos += len(result)
+ return result
+
+ async def readline(self):
+ self.tick()
+ find = self.data.find(b"\n", self.pos)
+ if find >= 0:
+ result = self.data[self.pos : find + 1]
+ else:
+ result = self.data[self.pos :]
+ self.pos += len(result)
+ return result
+
+ async def readexactly(self, length):
+ self.tick()
+ result = self.data[self.pos : self.pos + length]
+ if len(result) < length:
+ raise asyncio.IncompleteReadError(result, None)
+ self.pos += len(result)
+ return result
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index 6bf0034..bf59dbe 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest
+import redis
from redis.asyncio.connection import (
+ BaseParser,
Connection,
PythonParser,
UnixDomainSocketConnection,
@@ -16,6 +18,7 @@ from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
from tests.conftest import skip_if_server_version_lt
from .compat import mock
+from .mocks import MockStream
@pytest.mark.onlynoncluster
@@ -23,16 +26,19 @@ async def test_invalid_response(create_redis):
r = await create_redis(single_connection_client=True)
raw = b"x"
+ fake_stream = MockStream(raw + b"\r\n")
- parser: "PythonParser" = r.connection._parser
- if not isinstance(parser, PythonParser):
- pytest.skip("PythonParser only")
- stream_mock = mock.Mock(parser._stream)
- stream_mock.readline.return_value = raw + b"\r\n"
- with mock.patch.object(parser, "_stream", stream_mock):
+ parser: BaseParser = r.connection._parser
+ with mock.patch.object(parser, "_stream", fake_stream):
with pytest.raises(InvalidResponse) as cm:
await parser.read_response()
- assert str(cm.value) == f"Protocol Error: {raw!r}"
+ if isinstance(parser, PythonParser):
+ assert str(cm.value) == f"Protocol Error: {raw!r}"
+ else:
+ assert (
+ str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte'
+ )
+ await r.connection.disconnect()
@skip_if_server_version_lt("4.0.0")
@@ -112,3 +118,31 @@ async def test_connect_timeout_error_without_retry():
await conn.connect()
assert conn._connect.call_count == 1
assert str(e.value) == "Timeout connecting to server"
+
+
+@pytest.mark.onlynoncluster
+async def test_connection_parse_response_resume(r: redis.Redis):
+ """
+ This test verifies that the Connection parser,
+ be that PythonParser or HiredisParser,
+ can be interrupted at IO time and then resume parsing.
+ """
+ conn = Connection(**r.connection_pool.connection_kwargs)
+ await conn.connect()
+ message = (
+ b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
+ b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
+ )
+
+ conn._parser._stream = MockStream(message, interrupt_every=2)
+ for i in range(100):
+ try:
+ response = await conn.read_response()
+ break
+ except MockStream.TestError:
+ pass
+
+ else:
+ pytest.fail("didn't receive a response")
+ assert response
+ assert i > 0
diff --git a/tests/test_connection.py b/tests/test_connection.py
index d9251c3..e0b53cd 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -5,13 +5,15 @@ from unittest.mock import patch
import pytest
+import redis
from redis.backoff import NoBackoff
-from redis.connection import Connection
+from redis.connection import Connection, HiredisParser, PythonParser
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
from redis.retry import Retry
from redis.utils import HIREDIS_AVAILABLE
from .conftest import skip_if_server_version_lt
+from .mocks import MockSocket
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@@ -122,3 +124,42 @@ class TestConnection:
assert conn._connect.call_count == 1
assert str(e.value) == "Timeout connecting to server"
self.clear(conn)
+
+
+@pytest.mark.onlynoncluster
+@pytest.mark.parametrize(
+ "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"]
+)
+def test_connection_parse_response_resume(r: redis.Redis, parser_class):
+ """
+ This test verifies that the Connection parser,
+ be that PythonParser or HiredisParser,
+ can be interrupted at IO time and then resume parsing.
+ """
+ if parser_class is HiredisParser and not HIREDIS_AVAILABLE:
+ pytest.skip("Hiredis not available)")
+ args = dict(r.connection_pool.connection_kwargs)
+ args["parser_class"] = parser_class
+ conn = Connection(**args)
+ conn.connect()
+ message = (
+ b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
+ b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
+ )
+ mock_socket = MockSocket(message, interrupt_every=2)
+
+ if isinstance(conn._parser, PythonParser):
+ conn._parser._buffer._sock = mock_socket
+ else:
+ conn._parser._sock = mock_socket
+ for i in range(100):
+ try:
+ response = conn.read_response()
+ break
+ except MockSocket.TestError:
+ pass
+
+ else:
+ pytest.fail("didn't receive a response")
+ assert response
+ assert i > 0