summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristján Valur Jónsson <sweskman@gmail.com>2023-05-08 10:11:43 +0000
committerGitHub <noreply@github.com>2023-05-08 13:11:43 +0300
commitc0833f60a1d9ec85c589004aba6b6739e6298248 (patch)
tree9fbe069b992e8a2ff301ebce4722d22c9e5d8e80
parent093232d8b4cecaac5d8b15c908bd0f8f73927238 (diff)
downloadredis-py-c0833f60a1d9ec85c589004aba6b6739e6298248.tar.gz
Optionally disable disconnects in read_response (#2695)
* Add regression tests and fixes for issue #1128 * Fix tests for resumable read_response to use "disconnect_on_error" * undo prevision fix attempts in async client and cluster * re-enable cluster test * Suggestions from code review * Add CHANGES
-rw-r--r--CHANGES1
-rw-r--r--redis/asyncio/client.py93
-rw-r--r--redis/asyncio/cluster.py33
-rw-r--r--redis/asyncio/connection.py28
-rwxr-xr-xredis/client.py2
-rw-r--r--redis/connection.py24
-rw-r--r--tests/test_asyncio/test_commands.py38
-rw-r--r--tests/test_asyncio/test_connection.py2
-rw-r--r--tests/test_asyncio/test_cwe_404.py1
-rw-r--r--tests/test_commands.py35
-rw-r--r--tests/test_connection.py2
11 files changed, 149 insertions, 110 deletions
diff --git a/CHANGES b/CHANGES
index 3865ed1..ea171f6 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,4 @@
+ * Revert #2104, #2673, add `disconnect_on_error` option to `read_response()` (issues #2506, #2624)
* Add `address_remap` parameter to `RedisCluster`
* Fix incorrect usage of once flag in async Sentinel
* asyncio: Fix memory leak caused by hiredis (#2693)
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 5fb94b3..a7b888e 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -500,23 +500,6 @@ class Redis(
):
raise error
- async def _try_send_command_parse_response(self, conn, *args, **options):
- try:
- return await conn.retry.call_with_retry(
- lambda: self._send_command_parse_response(
- conn, args[0], *args, **options
- ),
- lambda error: self._disconnect_raise(conn, error),
- )
- except asyncio.CancelledError:
- await conn.disconnect(nowait=True)
- raise
- finally:
- if self.single_connection_client:
- self._single_conn_lock.release()
- if not self.connection:
- await self.connection_pool.release(conn)
-
# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
@@ -527,10 +510,18 @@ class Redis(
if self.single_connection_client:
await self._single_conn_lock.acquire()
-
- return await asyncio.shield(
- self._try_send_command_parse_response(conn, *args, **options)
- )
+ try:
+ return await conn.retry.call_with_retry(
+ lambda: self._send_command_parse_response(
+ conn, command_name, *args, **options
+ ),
+ lambda error: self._disconnect_raise(conn, error),
+ )
+ finally:
+ if self.single_connection_client:
+ self._single_conn_lock.release()
+ if not self.connection:
+ await pool.release(conn)
async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
@@ -774,18 +765,10 @@ class PubSub:
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()
-
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()
- async def _try_execute(self, conn, command, *arg, **kwargs):
- try:
- return await command(*arg, **kwargs)
- except asyncio.CancelledError:
- await conn.disconnect()
- raise
-
async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
@@ -794,11 +777,9 @@ class PubSub:
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
- return await asyncio.shield(
- conn.retry.call_with_retry(
- lambda: self._try_execute(conn, command, *args, **kwargs),
- lambda error: self._disconnect_raise_connect(conn, error),
- )
+ return await conn.retry.call_with_retry(
+ lambda: command(*args, **kwargs),
+ lambda error: self._disconnect_raise_connect(conn, error),
)
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -816,7 +797,9 @@ class PubSub:
await conn.connect()
read_timeout = None if block else timeout
- response = await self._execute(conn, conn.read_response, timeout=read_timeout)
+ response = await self._execute(
+ conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False
+ )
if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
@@ -1200,18 +1183,6 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
await self.reset()
raise
- async def _try_send_command_parse_response(self, conn, *args, **options):
- try:
- return await conn.retry.call_with_retry(
- lambda: self._send_command_parse_response(
- conn, args[0], *args, **options
- ),
- lambda error: self._disconnect_reset_raise(conn, error),
- )
- except asyncio.CancelledError:
- await conn.disconnect()
- raise
-
async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
@@ -1227,8 +1198,12 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
command_name, self.shard_hint
)
self.connection = conn
- return await asyncio.shield(
- self._try_send_command_parse_response(conn, *args, **options)
+
+ return await conn.retry.call_with_retry(
+ lambda: self._send_command_parse_response(
+ conn, command_name, *args, **options
+ ),
+ lambda error: self._disconnect_reset_raise(conn, error),
)
def pipeline_execute_command(self, *args, **options):
@@ -1396,19 +1371,6 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
await self.reset()
raise
- async def _try_execute(self, conn, execute, stack, raise_on_error):
- try:
- return await conn.retry.call_with_retry(
- lambda: execute(conn, stack, raise_on_error),
- lambda error: self._disconnect_raise_reset(conn, error),
- )
- except asyncio.CancelledError:
- # not supposed to be possible, yet here we are
- await conn.disconnect(nowait=True)
- raise
- finally:
- await self.reset()
-
async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
@@ -1430,11 +1392,10 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
conn = cast(Connection, conn)
try:
- return await asyncio.shield(
- self._try_execute(conn, execute, stack, raise_on_error)
+ return await conn.retry.call_with_retry(
+ lambda: execute(conn, stack, raise_on_error),
+ lambda error: self._disconnect_raise_reset(conn, error),
)
- except RuntimeError:
- await self.reset()
finally:
await self.reset()
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index eb5f4db..929d3e4 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -1016,33 +1016,12 @@ class ClusterNode:
await connection.send_packed_command(connection.pack_command(*args), False)
# Read response
- return await asyncio.shield(
- self._parse_and_release(connection, args[0], **kwargs)
- )
-
- async def _parse_and_release(self, connection, *args, **kwargs):
try:
- return await self.parse_response(connection, *args, **kwargs)
- except asyncio.CancelledError:
- # should not be possible
- await connection.disconnect(nowait=True)
- raise
+ return await self.parse_response(connection, args[0], **kwargs)
finally:
+ # Release connection
self._free.append(connection)
- async def _try_parse_response(self, cmd, connection, ret):
- try:
- cmd.result = await asyncio.shield(
- self.parse_response(connection, cmd.args[0], **cmd.kwargs)
- )
- except asyncio.CancelledError:
- await connection.disconnect(nowait=True)
- raise
- except Exception as e:
- cmd.result = e
- ret = True
- return ret
-
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
@@ -1055,7 +1034,13 @@ class ClusterNode:
# Read responses
ret = False
for cmd in commands:
- ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
+ try:
+ cmd.result = await self.parse_response(
+ connection, cmd.args[0], **cmd.kwargs
+ )
+ except Exception as e:
+ cmd.result = e
+ ret = True
# Release connection
self._free.append(connection)
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 59f75aa..462673f 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -804,7 +804,11 @@ class Connection:
raise ConnectionError(
f"Error {err_no} while writing to socket. {errmsg}."
) from e
- except Exception:
+ except BaseException:
+ # BaseExceptions can be raised when a socket send operation is not
+ # finished, e.g. due to a timeout. Ideally, a caller could then re-try
+ # to send un-sent data. However, the send_packed_command() API
+ # does not support it so there is no point in keeping the connection open.
await self.disconnect(nowait=True)
raise
@@ -828,6 +832,8 @@ class Connection:
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
+ *,
+ disconnect_on_error: bool = True,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
@@ -843,22 +849,24 @@ class Connection:
)
except asyncio.TimeoutError:
if timeout is not None:
- # user requested timeout, return None
+ # user requested timeout, return None. Operation can be retried
return None
# it was a self.socket_timeout error.
- await self.disconnect(nowait=True)
+ if disconnect_on_error:
+ await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
- await self.disconnect(nowait=True)
+ if disconnect_on_error:
+ await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
- except asyncio.CancelledError:
- # need this check for 3.7, where CancelledError
- # is subclass of Exception, not BaseException
- raise
- except Exception:
- await self.disconnect(nowait=True)
+ except BaseException:
+ # Also by default close in case of BaseException. A lot of code
+ # relies on this behaviour when doing Command/Response pairs.
+ # See #1128.
+ if disconnect_on_error:
+ await self.disconnect(nowait=True)
raise
if self.health_check_interval:
diff --git a/redis/client.py b/redis/client.py
index c43a388..65d0cec 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1533,7 +1533,7 @@ class PubSub:
return None
else:
conn.connect()
- return conn.read_response()
+ return conn.read_response(disconnect_on_error=False)
response = self._execute(conn, try_read)
diff --git a/redis/connection.py b/redis/connection.py
index 8b2389c..5af8928 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -834,7 +834,11 @@ class AbstractConnection:
errno = e.args[0]
errmsg = e.args[1]
raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
- except Exception:
+ except BaseException:
+ # BaseExceptions can be raised when a socket send operation is not
+ # finished, e.g. due to a timeout. Ideally, a caller could then re-try
+ # to send un-sent data. However, the send_packed_command() API
+ # does not support it so there is no point in keeping the connection open.
self.disconnect()
raise
@@ -859,7 +863,9 @@ class AbstractConnection:
self.disconnect()
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
- def read_response(self, disable_decoding=False):
+ def read_response(
+ self, disable_decoding=False, *, disconnect_on_error: bool = True
+ ):
"""Read the response from a previously sent command"""
host_error = self._host_error()
@@ -867,15 +873,21 @@ class AbstractConnection:
try:
response = self._parser.read_response(disable_decoding=disable_decoding)
except socket.timeout:
- self.disconnect()
+ if disconnect_on_error:
+ self.disconnect()
raise TimeoutError(f"Timeout reading from {host_error}")
except OSError as e:
- self.disconnect()
+ if disconnect_on_error:
+ self.disconnect()
raise ConnectionError(
f"Error while reading from {host_error}" f" : {e.args}"
)
- except Exception:
- self.disconnect()
+ except BaseException:
+ # Also by default close in case of BaseException. A lot of code
+ # relies on this behaviour when doing Command/Response pairs.
+ # See #1128.
+ if disconnect_on_error:
+ self.disconnect()
raise
if self.health_check_interval:
diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py
index 409934c..ac3537d 100644
--- a/tests/test_asyncio/test_commands.py
+++ b/tests/test_asyncio/test_commands.py
@@ -1,9 +1,11 @@
"""
Tests async overrides of commands from their mixins
"""
+import asyncio
import binascii
import datetime
import re
+import sys
from string import ascii_letters
import pytest
@@ -18,6 +20,11 @@ from tests.conftest import (
skip_unless_arch_bits,
)
+if sys.version_info >= (3, 11, 3):
+ from asyncio import timeout as async_timeout
+else:
+ from async_timeout import timeout as async_timeout
+
REDIS_6_VERSION = "5.9.0"
@@ -3008,6 +3015,37 @@ class TestRedisCommands:
for x in await r.module_list():
assert isinstance(x, dict)
+ @pytest.mark.onlynoncluster
+ async def test_interrupted_command(self, r: redis.Redis):
+ """
+ Regression test for issue #1128: An Un-handled BaseException
+ will leave the socket with un-read response to a previous
+ command.
+ """
+ ready = asyncio.Event()
+
+ async def helper():
+ with pytest.raises(asyncio.CancelledError):
+ # blocking pop
+ ready.set()
+ await r.brpop(["nonexist"])
+ # If the following is not done, further Timout operations will fail,
+ # because the timeout won't catch its Cancelled Error if the task
+ # has a pending cancel. Python documentation probably should reflect this.
+ if sys.version_info >= (3, 11):
+ asyncio.current_task().uncancel()
+ # if all is well, we can continue. The following should not hang.
+ await r.set("status", "down")
+
+ task = asyncio.create_task(helper())
+ await ready.wait()
+ await asyncio.sleep(0.01)
+ # the task is now sleeping, lets send it an exception
+ task.cancel()
+ # If all is well, the task should finish right away, otherwise fail with Timeout
+ async with async_timeout(0.1):
+ await task
+
@pytest.mark.onlynoncluster
class TestBinarySave:
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index e2d77fc..e49dd42 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -184,7 +184,7 @@ async def test_connection_parse_response_resume(r: redis.Redis):
conn._parser._stream = MockStream(message, interrupt_every=2)
for i in range(100):
try:
- response = await conn.read_response()
+ response = await conn.read_response(disconnect_on_error=False)
break
except MockStream.TestError:
pass
diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py
index d3a0666..21f2ddd 100644
--- a/tests/test_asyncio/test_cwe_404.py
+++ b/tests/test_asyncio/test_cwe_404.py
@@ -128,7 +128,6 @@ async def test_standalone(delay, master_host):
assert await r.get("foo") == b"foo"
-@pytest.mark.xfail(reason="cancel does not cause disconnect")
@pytest.mark.onlynoncluster
@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
async def test_standalone_pipeline(delay, master_host):
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 4020f5e..cb89669 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -1,9 +1,12 @@
import binascii
import datetime
import re
+import threading
import time
+from asyncio import CancelledError
from string import ascii_letters
from unittest import mock
+from unittest.mock import patch
import pytest
@@ -4743,6 +4746,38 @@ class TestRedisCommands:
res = r2.psync(r2.client_id(), 1)
assert b"FULLRESYNC" in res
+ @pytest.mark.onlynoncluster
+ def test_interrupted_command(self, r: redis.Redis):
+ """
+ Regression test for issue #1128: An Un-handled BaseException
+ will leave the socket with un-read response to a previous
+ command.
+ """
+
+ ok = False
+
+ def helper():
+ with pytest.raises(CancelledError):
+ # blocking pop
+ with patch.object(
+ r.connection._parser, "read_response", side_effect=CancelledError
+ ):
+ r.brpop(["nonexist"])
+ # if all is well, we can continue.
+ r.set("status", "down") # should not hang
+ nonlocal ok
+ ok = True
+
+ thread = threading.Thread(target=helper)
+ thread.start()
+ thread.join(0.1)
+ try:
+ assert not thread.is_alive()
+ assert ok
+ finally:
+ # disconnect here so that fixture cleanup can proceed
+ r.connection.disconnect()
+
@pytest.mark.onlynoncluster
class TestBinarySave:
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 25b4118..75ba738 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -160,7 +160,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class):
conn._parser._sock = mock_socket
for i in range(100):
try:
- response = conn.read_response()
+ response = conn.read_response(disconnect_on_error=False)
break
except MockSocket.TestError:
pass