summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChayim <chayim@users.noreply.github.com>2023-03-29 12:01:45 +0300
committerGitHub <noreply@github.com>2023-03-29 12:01:45 +0300
commit5acbde355058ab7d9c2f95bcef3993ab4134e342 (patch)
treebc90887cf2fc77d870254b5618d32a1a701c9186
parent6d886d7c7b405c0fe5d59ca192c87b438bf080f5 (diff)
downloadredis-py-5acbde355058ab7d9c2f95bcef3993ab4134e342.tar.gz
Fixing cancelled async futures (#2666)
Co-authored-by: James R T <jamestiotio@gmail.com> Co-authored-by: dvora-h <dvora.heller@redis.com>
-rw-r--r--.github/workflows/integration.yaml2
-rw-r--r--redis/asyncio/client.py99
-rw-r--r--redis/asyncio/cluster.py21
-rw-r--r--tests/test_asyncio/test_cluster.py17
-rw-r--r--tests/test_asyncio/test_connection.py21
-rw-r--r--tests/test_asyncio/test_cwe_404.py146
-rw-r--r--tests/test_asyncio/test_pubsub.py3
7 files changed, 234 insertions, 75 deletions
diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml
index 0f9db8f..f49a4fc 100644
--- a/.github/workflows/integration.yaml
+++ b/.github/workflows/integration.yaml
@@ -51,6 +51,7 @@ jobs:
timeout-minutes: 30
strategy:
max-parallel: 15
+ fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
test-type: ['standalone', 'cluster']
@@ -108,6 +109,7 @@ jobs:
name: Install package from commit hash
runs-on: ubuntu-latest
strategy:
+ fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
steps:
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 5de2ff9..7986b11 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -500,28 +500,37 @@ class Redis(
):
raise error
- # COMMAND EXECUTION AND PROTOCOL PARSING
- async def execute_command(self, *args, **options):
- """Execute a command and return a parsed response"""
- await self.initialize()
- pool = self.connection_pool
- command_name = args[0]
- conn = self.connection or await pool.get_connection(command_name, **options)
-
- if self.single_connection_client:
- await self._single_conn_lock.acquire()
+ async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
- conn, command_name, *args, **options
+ conn, args[0], *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
+ except asyncio.CancelledError:
+ await conn.disconnect(nowait=True)
+ raise
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
- await pool.release(conn)
+ await self.connection_pool.release(conn)
+
+ # COMMAND EXECUTION AND PROTOCOL PARSING
+ async def execute_command(self, *args, **options):
+ """Execute a command and return a parsed response"""
+ await self.initialize()
+ pool = self.connection_pool
+ command_name = args[0]
+ conn = self.connection or await pool.get_connection(command_name, **options)
+
+ if self.single_connection_client:
+ await self._single_conn_lock.acquire()
+
+ return await asyncio.shield(
+ self._try_send_command_parse_response(conn, *args, **options)
+ )
async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
@@ -765,10 +774,18 @@ class PubSub:
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()
+
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()
+ async def _try_execute(self, conn, command, *arg, **kwargs):
+ try:
+ return await command(*arg, **kwargs)
+ except asyncio.CancelledError:
+ await conn.disconnect()
+ raise
+
async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
@@ -777,9 +794,11 @@ class PubSub:
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
- return await conn.retry.call_with_retry(
- lambda: command(*args, **kwargs),
- lambda error: self._disconnect_raise_connect(conn, error),
+ return await asyncio.shield(
+ conn.retry.call_with_retry(
+ lambda: self._try_execute(conn, command, *args, **kwargs),
+ lambda error: self._disconnect_raise_connect(conn, error),
+ )
)
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -1181,6 +1200,18 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
await self.reset()
raise
+ async def _try_send_command_parse_response(self, conn, *args, **options):
+ try:
+ return await conn.retry.call_with_retry(
+ lambda: self._send_command_parse_response(
+ conn, args[0], *args, **options
+ ),
+ lambda error: self._disconnect_reset_raise(conn, error),
+ )
+ except asyncio.CancelledError:
+ await conn.disconnect()
+ raise
+
async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
@@ -1196,13 +1227,13 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
command_name, self.shard_hint
)
self.connection = conn
-
- return await conn.retry.call_with_retry(
- lambda: self._send_command_parse_response(
- conn, command_name, *args, **options
- ),
- lambda error: self._disconnect_reset_raise(conn, error),
- )
+ try:
+ return await asyncio.shield(
+ self._try_send_command_parse_response(conn, *args, **options)
+ )
+ except asyncio.CancelledError:
+ await conn.disconnect()
+ raise
def pipeline_execute_command(self, *args, **options):
"""
@@ -1369,6 +1400,19 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
await self.reset()
raise
+ async def _try_execute(self, conn, execute, stack, raise_on_error):
+ try:
+ return await conn.retry.call_with_retry(
+ lambda: execute(conn, stack, raise_on_error),
+ lambda error: self._disconnect_raise_reset(conn, error),
+ )
+ except asyncio.CancelledError:
+ # not supposed to be possible, yet here we are
+ await conn.disconnect(nowait=True)
+ raise
+ finally:
+ await self.reset()
+
async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
@@ -1391,15 +1435,10 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
try:
return await asyncio.shield(
- conn.retry.call_with_retry(
- lambda: execute(conn, stack, raise_on_error),
- lambda error: self._disconnect_raise_reset(conn, error),
- )
+ self._try_execute(conn, execute, stack, raise_on_error)
)
- except asyncio.CancelledError:
- # not supposed to be possible, yet here we are
- await conn.disconnect(nowait=True)
- raise
+ except RuntimeError:
+ await self.reset()
finally:
await self.reset()
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 569a076..a4a9561 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -1016,6 +1016,19 @@ class ClusterNode:
finally:
self._free.append(connection)
+ async def _try_parse_response(self, cmd, connection, ret):
+ try:
+ cmd.result = await asyncio.shield(
+ self.parse_response(connection, cmd.args[0], **cmd.kwargs)
+ )
+ except asyncio.CancelledError:
+ await connection.disconnect(nowait=True)
+ raise
+ except Exception as e:
+ cmd.result = e
+ ret = True
+ return ret
+
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
@@ -1028,13 +1041,7 @@ class ClusterNode:
# Read responses
ret = False
for cmd in commands:
- try:
- cmd.result = await self.parse_response(
- connection, cmd.args[0], **cmd.kwargs
- )
- except Exception as e:
- cmd.result = e
- ret = True
+ ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
# Release connection
self._free.append(connection)
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 0857c05..13e5e26 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -340,23 +340,6 @@ class TestRedisClusterObj:
rc = RedisCluster.from_url("rediss://localhost:16379")
assert rc.connection_kwargs["connection_class"] is SSLConnection
- async def test_asynckills(self, r) -> None:
-
- await r.set("foo", "foo")
- await r.set("bar", "bar")
-
- t = asyncio.create_task(r.get("foo"))
- await asyncio.sleep(1)
- t.cancel()
- try:
- await t
- except asyncio.CancelledError:
- pytest.fail("connection is left open with unread response")
-
- assert await r.get("bar") == b"bar"
- assert await r.ping()
- assert await r.get("foo") == b"foo"
-
async def test_max_connections(
self, create_redis: Callable[..., RedisCluster]
) -> None:
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index d3b6285..e2d77fc 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -44,27 +44,6 @@ async def test_invalid_response(create_redis):
await r.connection.disconnect()
-async def test_asynckills():
-
- for b in [True, False]:
- r = Redis(single_connection_client=b)
-
- await r.set("foo", "foo")
- await r.set("bar", "bar")
-
- t = asyncio.create_task(r.get("foo"))
- await asyncio.sleep(1)
- t.cancel()
- try:
- await t
- except asyncio.CancelledError:
- pytest.fail("connection left open with unread response")
-
- assert await r.get("bar") == b"bar"
- assert await r.ping()
- assert await r.get("foo") == b"foo"
-
-
@pytest.mark.onlynoncluster
async def test_single_connection():
"""Test that concurrent requests on a single client are synchronised."""
diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py
new file mode 100644
index 0000000..6683440
--- /dev/null
+++ b/tests/test_asyncio/test_cwe_404.py
@@ -0,0 +1,146 @@
+import asyncio
+import sys
+
+import pytest
+
+from redis.asyncio import Redis
+from redis.asyncio.cluster import RedisCluster
+
+
+async def pipe(
+ reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
+):
+ while True:
+ data = await reader.read(1000)
+ if not data:
+ break
+ await asyncio.sleep(delay)
+ writer.write(data)
+ await writer.drain()
+
+
+class DelayProxy:
+ def __init__(self, addr, redis_addr, delay: float):
+ self.addr = addr
+ self.redis_addr = redis_addr
+ self.delay = delay
+
+ async def start(self):
+ self.server = await asyncio.start_server(self.handle, *self.addr)
+ self.ROUTINE = asyncio.create_task(self.server.serve_forever())
+
+ async def handle(self, reader, writer):
+ # establish connection to redis
+ redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
+ pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
+ pipe2 = asyncio.create_task(
+ pipe(redis_reader, writer, self.delay, "from redis:")
+ )
+ await asyncio.gather(pipe1, pipe2)
+
+ async def stop(self):
+ # clean up enough so that we can reuse the looper
+ self.ROUTINE.cancel()
+ loop = self.server.get_loop()
+ await loop.shutdown_asyncgens()
+
+
+@pytest.mark.onlynoncluster
+@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
+async def test_standalone(delay):
+
+ # create a tcp socket proxy that relays data to Redis and back,
+ # inserting 0.1 seconds of delay
+ dp = DelayProxy(
+ addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2
+ )
+ await dp.start()
+
+ for b in [True, False]:
+ # note that we connect to proxy, rather than to Redis directly
+ async with Redis(host="localhost", port=5380, single_connection_client=b) as r:
+
+ await r.set("foo", "foo")
+ await r.set("bar", "bar")
+
+ t = asyncio.create_task(r.get("foo"))
+ await asyncio.sleep(delay)
+ t.cancel()
+ try:
+ await t
+ sys.stderr.write("try again, we did not cancel the task in time\n")
+ except asyncio.CancelledError:
+ sys.stderr.write(
+ "canceled task, connection is left open with unread response\n"
+ )
+
+ assert await r.get("bar") == b"bar"
+ assert await r.ping()
+ assert await r.get("foo") == b"foo"
+
+ await dp.stop()
+
+
+@pytest.mark.onlynoncluster
+@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
+async def test_standalone_pipeline(delay):
+ dp = DelayProxy(
+ addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2
+ )
+ await dp.start()
+ async with Redis(host="localhost", port=5380) as r:
+ await r.set("foo", "foo")
+ await r.set("bar", "bar")
+
+ pipe = r.pipeline()
+
+ pipe2 = r.pipeline()
+ pipe2.get("bar")
+ pipe2.ping()
+ pipe2.get("foo")
+
+ t = asyncio.create_task(pipe.get("foo").execute())
+ await asyncio.sleep(delay)
+ t.cancel()
+
+ pipe.get("bar")
+ pipe.ping()
+ pipe.get("foo")
+ pipe.reset()
+
+ assert await pipe.execute() is None
+
+ # validating that the pipeline can be used as it could previously
+ pipe.get("bar")
+ pipe.ping()
+ pipe.get("foo")
+ assert await pipe.execute() == [b"bar", True, b"foo"]
+ assert await pipe2.execute() == [b"bar", True, b"foo"]
+
+ await dp.stop()
+
+
+@pytest.mark.onlycluster
+async def test_cluster(request):
+
+ dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1)
+ await dp.start()
+
+ r = RedisCluster.from_url("redis://localhost:5381")
+ await r.initialize()
+ await r.set("foo", "foo")
+ await r.set("bar", "bar")
+
+ t = asyncio.create_task(r.get("foo"))
+ await asyncio.sleep(0.050)
+ t.cancel()
+ try:
+ await t
+ except asyncio.CancelledError:
+ pytest.fail("connection is left open with unread response")
+
+ assert await r.get("bar") == b"bar"
+ assert await r.ping()
+ assert await r.get("foo") == b"foo"
+
+ await dp.stop()
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
index 8f3817a..ba70782 100644
--- a/tests/test_asyncio/test_pubsub.py
+++ b/tests/test_asyncio/test_pubsub.py
@@ -973,6 +973,9 @@ class TestBaseException:
# the timeout on the read should not cause disconnect
assert pubsub.connection.is_connected
+ @pytest.mark.skipif(
+ sys.version_info < (3, 8), reason="requires python 3.8 or higher"
+ )
async def test_base_exception(self, r: redis.Redis):
"""
Manually trigger a BaseException inside the parser's .read_response method