summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Lemoine <eric.lemoine@getalma.eu>2022-06-19 03:56:53 +0200
committerGitHub <noreply@github.com>2022-06-19 04:56:53 +0300
commitbea72995fd39b01e2f0a1682b16b6c7690933f36 (patch)
tree477b9093e9664a13add96681a2012ded0ffbc798
parent33702983b8b0a55d29189babb631ea108ee8404f (diff)
downloadredis-py-bea72995fd39b01e2f0a1682b16b6c7690933f36.tar.gz
Fix retries in async mode (#2180)
* Avoid mutating a global retry_on_error list * Make retries config consistent in sync and async * Fix async retries * Add new TestConnectionConstructorWithRetry tests
-rw-r--r--redis/asyncio/client.py17
-rw-r--r--redis/asyncio/connection.py17
-rw-r--r--redis/asyncio/retry.py8
-rwxr-xr-xredis/client.py4
-rwxr-xr-xredis/connection.py8
-rw-r--r--tests/test_asyncio/test_retry.py38
6 files changed, 83 insertions, 9 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 6db5489..3d59016 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -158,6 +158,7 @@ class Redis(
encoding_errors: str = "strict",
decode_responses: bool = False,
retry_on_timeout: bool = False,
+ retry_on_error: Optional[list] = None,
ssl: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
@@ -176,8 +177,10 @@ class Redis(
):
"""
Initialize a new Redis client.
- To specify a retry policy, first set `retry_on_timeout` to `True`
- then set `retry` to a valid `Retry` object
+ To specify a retry policy for specific errors, first set
+ `retry_on_error` to a list of the error/s to retry on, then set
+ `retry` to a valid `Retry` object.
+ To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
kwargs: Dict[str, Any]
# auto_close_connection_pool only has an effect if connection_pool is
@@ -188,6 +191,10 @@ class Redis(
auto_close_connection_pool if connection_pool is None else False
)
if not connection_pool:
+ if not retry_on_error:
+ retry_on_error = []
+ if retry_on_timeout is True:
+ retry_on_error.append(TimeoutError)
kwargs = {
"db": db,
"username": username,
@@ -197,6 +204,7 @@ class Redis(
"encoding_errors": encoding_errors,
"decode_responses": decode_responses,
"retry_on_timeout": retry_on_timeout,
+ "retry_on_error": retry_on_error,
"retry": copy.deepcopy(retry),
"max_connections": max_connections,
"health_check_interval": health_check_interval,
@@ -461,7 +469,10 @@ class Redis(
is not a TimeoutError
"""
await conn.disconnect()
- if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+ if (
+ conn.retry_on_error is None
+ or isinstance(error, tuple(conn.retry_on_error)) is False
+ ):
raise error
# COMMAND EXECUTION AND PROTOCOL PARSING
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 38465fc..35536fc 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -578,6 +578,7 @@ class Connection:
"socket_type",
"redis_connect_func",
"retry_on_timeout",
+ "retry_on_error",
"health_check_interval",
"next_health_check",
"last_active_at",
@@ -606,6 +607,7 @@ class Connection:
socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
socket_type: int = 0,
retry_on_timeout: bool = False,
+ retry_on_error: Union[list, _Sentinel] = SENTINEL,
encoding: str = "utf-8",
encoding_errors: str = "strict",
decode_responses: bool = False,
@@ -631,12 +633,19 @@ class Connection:
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
+ if retry_on_error is SENTINEL:
+ retry_on_error = []
if retry_on_timeout:
+ retry_on_error.append(TimeoutError)
+ self.retry_on_error = retry_on_error
+ if retry_on_error:
if not retry:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
self.retry = copy.deepcopy(retry)
+ # Update the retry's supported errors with the specified errors
+ self.retry.update_supported_errors(retry_on_error)
else:
self.retry = Retry(NoBackoff(), 0)
self.health_check_interval = health_check_interval
@@ -1169,6 +1178,7 @@ class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init]
encoding_errors: str = "strict",
decode_responses: bool = False,
retry_on_timeout: bool = False,
+ retry_on_error: Union[list, _Sentinel] = SENTINEL,
parser_class: Type[BaseParser] = DefaultParser,
socket_read_size: int = 65536,
health_check_interval: float = 0.0,
@@ -1190,12 +1200,19 @@ class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init]
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
self.retry_on_timeout = retry_on_timeout
+ if retry_on_error is SENTINEL:
+ retry_on_error = []
if retry_on_timeout:
+ retry_on_error.append(TimeoutError)
+ self.retry_on_error = retry_on_error
+ if retry_on_error:
if retry is None:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
self.retry = copy.deepcopy(retry)
+ # Update the retry's supported errors with the specified errors
+ self.retry.update_supported_errors(retry_on_error)
else:
self.retry = Retry(NoBackoff(), 0)
self.health_check_interval = health_check_interval
diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py
index 0934ad0..7c5e3b0 100644
--- a/redis/asyncio/retry.py
+++ b/redis/asyncio/retry.py
@@ -35,6 +35,14 @@ class Retry:
self._retries = retries
self._supported_errors = supported_errors
+ def update_supported_errors(self, specified_errors: list):
+ """
+ Updates the supported errors with the specified error types
+ """
+ self._supported_errors = tuple(
+ set(self._supported_errors + tuple(specified_errors))
+ )
+
async def call_with_retry(
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
) -> T:
diff --git a/redis/client.py b/redis/client.py
index fcc2758..86061d5 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -914,7 +914,7 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands):
errors=None,
decode_responses=False,
retry_on_timeout=False,
- retry_on_error=[],
+ retry_on_error=None,
ssl=False,
ssl_keyfile=None,
ssl_certfile=None,
@@ -958,6 +958,8 @@ class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands):
)
)
encoding_errors = errors
+ if not retry_on_error:
+ retry_on_error = []
if retry_on_timeout is True:
retry_on_error.append(TimeoutError)
kwargs = {
diff --git a/redis/connection.py b/redis/connection.py
index 1bc2ae1..3438baf 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -515,7 +515,7 @@ class Connection:
socket_keepalive_options=None,
socket_type=0,
retry_on_timeout=False,
- retry_on_error=[],
+ retry_on_error=SENTINEL,
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
@@ -547,6 +547,8 @@ class Connection:
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
+ if retry_on_error is SENTINEL:
+ retry_on_error = []
if retry_on_timeout:
# Add TimeoutError to the errors list to retry on
retry_on_error.append(TimeoutError)
@@ -1065,7 +1067,7 @@ class UnixDomainSocketConnection(Connection):
encoding_errors="strict",
decode_responses=False,
retry_on_timeout=False,
- retry_on_error=[],
+ retry_on_error=SENTINEL,
parser_class=DefaultParser,
socket_read_size=65536,
health_check_interval=0,
@@ -1088,6 +1090,8 @@ class UnixDomainSocketConnection(Connection):
self.password = password
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
+ if retry_on_error is SENTINEL:
+ retry_on_error = []
if retry_on_timeout:
# Add TimeoutError to the errors list to retry on
retry_on_error.append(TimeoutError)
diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py
index d696d72..38e353b 100644
--- a/tests/test_asyncio/test_retry.py
+++ b/tests/test_asyncio/test_retry.py
@@ -3,7 +3,7 @@ import pytest
from redis.asyncio.connection import Connection, UnixDomainSocketConnection
from redis.asyncio.retry import Retry
from redis.backoff import AbstractBackoff, NoBackoff
-from redis.exceptions import ConnectionError
+from redis.exceptions import ConnectionError, TimeoutError
class BackoffMock(AbstractBackoff):
@@ -22,9 +22,28 @@ class BackoffMock(AbstractBackoff):
class TestConnectionConstructorWithRetry:
"Test that the Connection constructors properly handles Retry objects"
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_on_error_set(self, Class):
+ class CustomError(Exception):
+ pass
+
+ retry_on_error = [ConnectionError, TimeoutError, CustomError]
+ c = Class(retry_on_error=retry_on_error)
+ assert c.retry_on_error == retry_on_error
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == 1
+ assert set(c.retry._supported_errors) == set(retry_on_error)
+
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_on_error_not_set(self, Class):
+ c = Class()
+ assert c.retry_on_error == []
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == 0
+
@pytest.mark.parametrize("retry_on_timeout", [False, True])
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
- def test_retry_on_timeout_boolean(self, Class, retry_on_timeout):
+ def test_retry_on_timeout(self, Class, retry_on_timeout):
c = Class(retry_on_timeout=retry_on_timeout)
assert c.retry_on_timeout == retry_on_timeout
assert isinstance(c.retry, Retry)
@@ -32,13 +51,26 @@ class TestConnectionConstructorWithRetry:
@pytest.mark.parametrize("retries", range(10))
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
- def test_retry_on_timeout_retry(self, Class, retries: int):
+ def test_retry_with_retry_on_timeout(self, Class, retries: int):
retry_on_timeout = retries > 0
c = Class(retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries))
assert c.retry_on_timeout == retry_on_timeout
assert isinstance(c.retry, Retry)
assert c.retry._retries == retries
+ @pytest.mark.parametrize("retries", range(10))
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_with_retry_on_error(self, Class, retries: int):
+ class CustomError(Exception):
+ pass
+
+ retry_on_error = [ConnectionError, TimeoutError, CustomError]
+ c = Class(retry_on_error=retry_on_error, retry=Retry(NoBackoff(), retries))
+ assert c.retry_on_error == retry_on_error
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == retries
+ assert set(c.retry._supported_errors) == set(retry_on_error)
+
class TestRetry:
"Test that Retry calls backoff and retries the expected number of times"