summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChayim <chayim@users.noreply.github.com>2023-03-29 12:01:45 +0300
committerdvora-h <dvora.heller@redis.com>2023-03-29 12:44:43 +0300
commit6cd5173b29f80adfefa7c6350bdd8cd6c98ac94f (patch)
tree6190b0f047810edea02cf83ed728c2707eef33de
parent7b48b1bb34f97e4adf32c12a987ff593dc536aef (diff)
downloadredis-py-6cd5173b29f80adfefa7c6350bdd8cd6c98ac94f.tar.gz
Fixing cancelled async futures (#2666)
Co-authored-by: James R T <jamestiotio@gmail.com> Co-authored-by: dvora-h <dvora.heller@redis.com>
-rw-r--r--redis/asyncio/client.py94
-rw-r--r--redis/asyncio/cluster.py21
-rw-r--r--tests/test_asyncio/test_cluster.py17
-rw-r--r--tests/test_asyncio/test_connection.py23
-rw-r--r--tests/test_asyncio/test_cwe_404.py146
5 files changed, 226 insertions, 75 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 5c0b546..f5d8b01 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -475,24 +475,32 @@ class Redis(
):
raise error
- # COMMAND EXECUTION AND PROTOCOL PARSING
- async def execute_command(self, *args, **options):
- """Execute a command and return a parsed response"""
- await self.initialize()
- pool = self.connection_pool
- command_name = args[0]
- conn = self.connection or await pool.get_connection(command_name, **options)
-
+ async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
- conn, command_name, *args, **options
+ conn, args[0], *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
+ except asyncio.CancelledError:
+ await conn.disconnect(nowait=True)
+ raise
finally:
if not self.connection:
- await pool.release(conn)
+ await self.connection_pool.release(conn)
+
+ # COMMAND EXECUTION AND PROTOCOL PARSING
+ async def execute_command(self, *args, **options):
+ """Execute a command and return a parsed response"""
+ await self.initialize()
+ pool = self.connection_pool
+ command_name = args[0]
+ conn = self.connection or await pool.get_connection(command_name, **options)
+
+ return await asyncio.shield(
+ self._try_send_command_parse_response(conn, *args, **options)
+ )
async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
@@ -726,10 +734,18 @@ class PubSub:
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()
+
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()
+ async def _try_execute(self, conn, command, *arg, **kwargs):
+ try:
+ return await command(*arg, **kwargs)
+ except asyncio.CancelledError:
+ await conn.disconnect()
+ raise
+
async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
@@ -738,9 +754,11 @@ class PubSub:
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
- return await conn.retry.call_with_retry(
- lambda: command(*args, **kwargs),
- lambda error: self._disconnect_raise_connect(conn, error),
+ return await asyncio.shield(
+ conn.retry.call_with_retry(
+ lambda: self._try_execute(conn, command, *args, **kwargs),
+ lambda error: self._disconnect_raise_connect(conn, error),
+ )
)
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -1140,6 +1158,18 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
await self.reset()
raise
+ async def _try_send_command_parse_response(self, conn, *args, **options):
+ try:
+ return await conn.retry.call_with_retry(
+ lambda: self._send_command_parse_response(
+ conn, args[0], *args, **options
+ ),
+ lambda error: self._disconnect_reset_raise(conn, error),
+ )
+ except asyncio.CancelledError:
+ await conn.disconnect()
+ raise
+
async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
@@ -1155,13 +1185,13 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
command_name, self.shard_hint
)
self.connection = conn
-
- return await conn.retry.call_with_retry(
- lambda: self._send_command_parse_response(
- conn, command_name, *args, **options
- ),
- lambda error: self._disconnect_reset_raise(conn, error),
- )
+ try:
+ return await asyncio.shield(
+ self._try_send_command_parse_response(conn, *args, **options)
+ )
+ except asyncio.CancelledError:
+ await conn.disconnect()
+ raise
def pipeline_execute_command(self, *args, **options):
"""
@@ -1328,6 +1358,19 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
await self.reset()
raise
+ async def _try_execute(self, conn, execute, stack, raise_on_error):
+ try:
+ return await conn.retry.call_with_retry(
+ lambda: execute(conn, stack, raise_on_error),
+ lambda error: self._disconnect_raise_reset(conn, error),
+ )
+ except asyncio.CancelledError:
+ # not supposed to be possible, yet here we are
+ await conn.disconnect(nowait=True)
+ raise
+ finally:
+ await self.reset()
+
async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
@@ -1350,15 +1393,10 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
try:
return await asyncio.shield(
- conn.retry.call_with_retry(
- lambda: execute(conn, stack, raise_on_error),
- lambda error: self._disconnect_raise_reset(conn, error),
- )
+ self._try_execute(conn, execute, stack, raise_on_error)
)
- except asyncio.CancelledError:
- # not supposed to be possible, yet here we are
- await conn.disconnect(nowait=True)
- raise
+ except RuntimeError:
+ await self.reset()
finally:
await self.reset()
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 8dfb1cb..50f1b6b 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -893,6 +893,19 @@ class ClusterNode:
finally:
self._free.append(connection)
+ async def _try_parse_response(self, cmd, connection, ret):
+ try:
+ cmd.result = await asyncio.shield(
+ self.parse_response(connection, cmd.args[0], **cmd.kwargs)
+ )
+ except asyncio.CancelledError:
+ await connection.disconnect(nowait=True)
+ raise
+ except Exception as e:
+ cmd.result = e
+ ret = True
+ return ret
+
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
@@ -905,13 +918,7 @@ class ClusterNode:
# Read responses
ret = False
for cmd in commands:
- try:
- cmd.result = await self.parse_response(
- connection, cmd.args[0], **cmd.kwargs
- )
- except Exception as e:
- cmd.result = e
- ret = True
+ ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
# Release connection
self._free.append(connection)
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 2e44cdd..d6e01f7 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -333,23 +333,6 @@ class TestRedisClusterObj:
called_count += 1
assert called_count == 1
- async def test_asynckills(self, r) -> None:
-
- await r.set("foo", "foo")
- await r.set("bar", "bar")
-
- t = asyncio.create_task(r.get("foo"))
- await asyncio.sleep(1)
- t.cancel()
- try:
- await t
- except asyncio.CancelledError:
- pytest.fail("connection is left open with unread response")
-
- assert await r.get("bar") == b"bar"
- assert await r.ping()
- assert await r.get("foo") == b"foo"
-
async def test_execute_command_default_node(self, r: RedisCluster) -> None:
"""
Test command execution without node flag is being executed on the
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index c414ee0..f6259ad 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -28,29 +28,6 @@ async def test_invalid_response(create_redis):
assert str(cm.value) == f"Protocol Error: {raw!r}"
-@pytest.mark.onlynoncluster
-async def test_asynckills():
- from redis.asyncio.client import Redis
-
- for b in [True, False]:
- r = Redis(single_connection_client=b)
-
- await r.set("foo", "foo")
- await r.set("bar", "bar")
-
- t = asyncio.create_task(r.get("foo"))
- await asyncio.sleep(1)
- t.cancel()
- try:
- await t
- except asyncio.CancelledError:
- pytest.fail("connection left open with unread response")
-
- assert await r.get("bar") == b"bar"
- assert await r.ping()
- assert await r.get("foo") == b"foo"
-
-
@skip_if_server_version_lt("4.0.0")
@pytest.mark.redismod
@pytest.mark.onlynoncluster
diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py
new file mode 100644
index 0000000..6683440
--- /dev/null
+++ b/tests/test_asyncio/test_cwe_404.py
@@ -0,0 +1,146 @@
+import asyncio
+import sys
+
+import pytest
+
+from redis.asyncio import Redis
+from redis.asyncio.cluster import RedisCluster
+
+
+async def pipe(
+ reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
+):
+ while True:
+ data = await reader.read(1000)
+ if not data:
+ break
+ await asyncio.sleep(delay)
+ writer.write(data)
+ await writer.drain()
+
+
+class DelayProxy:
+ def __init__(self, addr, redis_addr, delay: float):
+ self.addr = addr
+ self.redis_addr = redis_addr
+ self.delay = delay
+
+ async def start(self):
+ self.server = await asyncio.start_server(self.handle, *self.addr)
+ self.ROUTINE = asyncio.create_task(self.server.serve_forever())
+
+ async def handle(self, reader, writer):
+ # establish connection to redis
+ redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
+ pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
+ pipe2 = asyncio.create_task(
+ pipe(redis_reader, writer, self.delay, "from redis:")
+ )
+ await asyncio.gather(pipe1, pipe2)
+
+ async def stop(self):
+ # clean up enough so that we can reuse the looper
+ self.ROUTINE.cancel()
+ loop = self.server.get_loop()
+ await loop.shutdown_asyncgens()
+
+
+@pytest.mark.onlynoncluster
+@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
+async def test_standalone(delay):
+
+ # create a tcp socket proxy that relays data to Redis and back,
+ # inserting 0.1 seconds of delay
+ dp = DelayProxy(
+ addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2
+ )
+ await dp.start()
+
+ for b in [True, False]:
+ # note that we connect to proxy, rather than to Redis directly
+ async with Redis(host="localhost", port=5380, single_connection_client=b) as r:
+
+ await r.set("foo", "foo")
+ await r.set("bar", "bar")
+
+ t = asyncio.create_task(r.get("foo"))
+ await asyncio.sleep(delay)
+ t.cancel()
+ try:
+ await t
+ sys.stderr.write("try again, we did not cancel the task in time\n")
+ except asyncio.CancelledError:
+ sys.stderr.write(
+ "canceled task, connection is left open with unread response\n"
+ )
+
+ assert await r.get("bar") == b"bar"
+ assert await r.ping()
+ assert await r.get("foo") == b"foo"
+
+ await dp.stop()
+
+
+@pytest.mark.onlynoncluster
+@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
+async def test_standalone_pipeline(delay):
+ dp = DelayProxy(
+ addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2
+ )
+ await dp.start()
+ async with Redis(host="localhost", port=5380) as r:
+ await r.set("foo", "foo")
+ await r.set("bar", "bar")
+
+ pipe = r.pipeline()
+
+ pipe2 = r.pipeline()
+ pipe2.get("bar")
+ pipe2.ping()
+ pipe2.get("foo")
+
+ t = asyncio.create_task(pipe.get("foo").execute())
+ await asyncio.sleep(delay)
+ t.cancel()
+
+ pipe.get("bar")
+ pipe.ping()
+ pipe.get("foo")
+ pipe.reset()
+
+ assert await pipe.execute() is None
+
+ # validating that the pipeline can be used as it could previously
+ pipe.get("bar")
+ pipe.ping()
+ pipe.get("foo")
+ assert await pipe.execute() == [b"bar", True, b"foo"]
+ assert await pipe2.execute() == [b"bar", True, b"foo"]
+
+ await dp.stop()
+
+
+@pytest.mark.onlycluster
+async def test_cluster(request):
+
+ dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1)
+ await dp.start()
+
+ r = RedisCluster.from_url("redis://localhost:5381")
+ await r.initialize()
+ await r.set("foo", "foo")
+ await r.set("bar", "bar")
+
+ t = asyncio.create_task(r.get("foo"))
+ await asyncio.sleep(0.050)
+ t.cancel()
+ try:
+ await t
+ except asyncio.CancelledError:
+ pytest.fail("connection is left open with unread response")
+
+ assert await r.get("bar") == b"bar"
+ assert await r.ping()
+ assert await r.get("foo") == b"foo"
+
+ await dp.stop()