summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristján Valur Jónsson <sweskman@gmail.com>2022-09-29 13:37:55 +0000
committerGitHub <noreply@github.com>2022-09-29 16:37:55 +0300
commitb0883b791f95a595fae70bcedf3ad0f73c00e258 (patch)
tree6bfeb975aa6b9ec6325e90246230d01172beb1a6
parentcdbc662adcd303d2525f3ace70531aa37a755652 (diff)
downloadredis-py-b0883b791f95a595fae70bcedf3ad0f73c00e258.tar.gz
Simplify async timeouts and allowing `timeout=None` in `PubSub.get_message()` to wait forever (#2295)v4.4.0rc2
* Avoid an extra "can_read" call and use timeout directly. * Remove low-level read timeouts from the Parser, now handled in the Connection * Allow pubsub.get_message(time=None) to block. * update Changes * increase test timeout for robustness * expand with statement to avoid invoking null context managers. remove nullcontext * Remove unused import
-rw-r--r--CHANGES1
-rw-r--r--redis/asyncio/client.py22
-rw-r--r--redis/asyncio/connection.py138
-rwxr-xr-xredis/client.py6
-rw-r--r--tests/test_asyncio/test_pubsub.py2
5 files changed, 57 insertions, 112 deletions
diff --git a/CHANGES b/CHANGES
index a5b5029..2ced3d8 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,4 @@
+ * Allow `timeout=None` in `PubSub.get_message()` to wait forever
* add `nowait` flag to `asyncio.Connection.disconnect()`
* Update README.md links
* Fix timezone handling for datetime to unixtime conversions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index c13054b..0e40ed7 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -24,8 +24,6 @@ from typing import (
cast,
)
-import async_timeout
-
from redis.asyncio.connection import (
Connection,
ConnectionPool,
@@ -759,18 +757,8 @@ class PubSub:
if not conn.is_connected:
await conn.connect()
- if not block:
-
- async def read_with_timeout():
- try:
- async with async_timeout.timeout(timeout):
- return await conn.read_response()
- except asyncio.TimeoutError:
- return None
-
- response = await self._execute(conn, read_with_timeout)
- else:
- response = await self._execute(conn, conn.read_response)
+ read_timeout = None if block else timeout
+ response = await self._execute(conn, conn.read_response, timeout=read_timeout)
if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
@@ -882,16 +870,16 @@ class PubSub:
yield response
async def get_message(
- self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
+ self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
"""
Get the next message if one is available, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
- number.
+ number or None to wait indefinitely.
"""
- response = await self.parse_response(block=False, timeout=timeout)
+ response = await self.parse_response(block=(timeout is None), timeout=timeout)
if response:
return await self.handle_message(response, ignore_subscribe_messages)
return None
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 53b41af..c8834c9 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -1,7 +1,6 @@
import asyncio
import copy
import enum
-import errno
import inspect
import io
import os
@@ -55,16 +54,6 @@ hiredis = None
if HIREDIS_AVAILABLE:
import hiredis
-NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
- BlockingIOError: errno.EWOULDBLOCK,
- ssl.SSLWantReadError: 2,
- ssl.SSLWantWriteError: 2,
- ssl.SSLError: 2,
-}
-
-NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())
-
-
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
@@ -229,11 +218,9 @@ class SocketBuffer:
self,
stream_reader: asyncio.StreamReader,
socket_read_size: int,
- socket_timeout: Optional[float],
):
self._stream: Optional[asyncio.StreamReader] = stream_reader
self.socket_read_size = socket_read_size
- self.socket_timeout = socket_timeout
self._buffer: Optional[io.BytesIO] = io.BytesIO()
# number of bytes written to the buffer from the socket
self.bytes_written = 0
@@ -244,52 +231,35 @@ class SocketBuffer:
def length(self):
return self.bytes_written - self.bytes_read
- async def _read_from_socket(
- self,
- length: Optional[int] = None,
- timeout: Union[float, None, _Sentinel] = SENTINEL,
- raise_on_timeout: bool = True,
- ) -> bool:
+ async def _read_from_socket(self, length: Optional[int] = None) -> bool:
buf = self._buffer
if buf is None or self._stream is None:
raise RedisError("Buffer is closed.")
buf.seek(self.bytes_written)
marker = 0
- timeout = timeout if timeout is not SENTINEL else self.socket_timeout
- try:
- while True:
- async with async_timeout.timeout(timeout):
- data = await self._stream.read(self.socket_read_size)
- # an empty string indicates the server shutdown the socket
- if isinstance(data, bytes) and len(data) == 0:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- buf.write(data)
- data_length = len(data)
- self.bytes_written += data_length
- marker += data_length
-
- if length is not None and length > marker:
- continue
- return True
- except (socket.timeout, asyncio.TimeoutError):
- if raise_on_timeout:
- raise TimeoutError("Timeout reading from socket")
- return False
- except NONBLOCKING_EXCEPTIONS as ex:
- # if we're in nonblocking mode and the recv raises a
- # blocking error, simply return False indicating that
- # there's no data to be read. otherwise raise the
- # original exception.
- allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
- if not raise_on_timeout and ex.errno == allowed:
- return False
- raise ConnectionError(f"Error while reading from socket: {ex.args}")
+ while True:
+ data = await self._stream.read(self.socket_read_size)
+ # an empty string indicates the server shutdown the socket
+ if isinstance(data, bytes) and len(data) == 0:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ buf.write(data)
+ data_length = len(data)
+ self.bytes_written += data_length
+ marker += data_length
+
+ if length is not None and length > marker:
+ continue
+ return True
async def can_read_destructive(self) -> bool:
- return bool(self.length) or await self._read_from_socket(
- timeout=0, raise_on_timeout=False
- )
+ if self.length:
+ return True
+ try:
+ async with async_timeout.timeout(0):
+ return await self._read_from_socket()
+ except asyncio.TimeoutError:
+ return False
async def read(self, length: int) -> bytes:
length = length + 2 # make sure to read the \r\n terminator
@@ -372,9 +342,7 @@ class PythonParser(BaseParser):
if self._stream is None:
raise RedisError("Buffer is closed.")
- self._buffer = SocketBuffer(
- self._stream, self._read_size, connection.socket_timeout
- )
+ self._buffer = SocketBuffer(self._stream, self._read_size)
self.encoder = connection.encoder
def on_disconnect(self):
@@ -444,14 +412,13 @@ class PythonParser(BaseParser):
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""
- __slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout")
+ __slots__ = BaseParser.__slots__ + ("_reader",)
def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader: Optional[hiredis.Reader] = None
- self._socket_timeout: Optional[float] = None
def on_connect(self, connection: "Connection"):
self._stream = connection._reader
@@ -464,7 +431,6 @@ class HiredisParser(BaseParser):
kwargs["errors"] = connection.encoder.encoding_errors
self._reader = hiredis.Reader(**kwargs)
- self._socket_timeout = connection.socket_timeout
def on_disconnect(self):
self._stream = None
@@ -475,39 +441,20 @@ class HiredisParser(BaseParser):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._reader.gets():
return True
- return await self.read_from_socket(timeout=0, raise_on_timeout=False)
-
- async def read_from_socket(
- self,
- timeout: Union[float, None, _Sentinel] = SENTINEL,
- raise_on_timeout: bool = True,
- ):
- timeout = self._socket_timeout if timeout is SENTINEL else timeout
try:
- if timeout is None:
- buffer = await self._stream.read(self._read_size)
- else:
- async with async_timeout.timeout(timeout):
- buffer = await self._stream.read(self._read_size)
- if not buffer or not isinstance(buffer, bytes):
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
- self._reader.feed(buffer)
- # data was read from the socket and added to the buffer.
- # return True to indicate that data was read.
- return True
- except (socket.timeout, asyncio.TimeoutError):
- if raise_on_timeout:
- raise TimeoutError("Timeout reading from socket") from None
+ async with async_timeout.timeout(0):
+ return await self.read_from_socket()
+ except asyncio.TimeoutError:
return False
- except NONBLOCKING_EXCEPTIONS as ex:
- # if we're in nonblocking mode and the recv raises a
- # blocking error, simply return False indicating that
- # there's no data to be read. otherwise raise the
- # original exception.
- allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
- if not raise_on_timeout and ex.errno == allowed:
- return False
- raise ConnectionError(f"Error while reading from socket: {ex.args}")
+
+ async def read_from_socket(self):
+ buffer = await self._stream.read(self._read_size)
+ if not buffer or not isinstance(buffer, bytes):
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
+ self._reader.feed(buffer)
+ # data was read from the socket and added to the buffer.
+ # return True to indicate that data was read.
+ return True
async def read_response(
self, disable_decoding: bool = False
@@ -922,11 +869,16 @@ class Connection:
f"Error while reading from {self.host}:{self.port}: {e.args}"
)
- async def read_response(self, disable_decoding: bool = False):
+ async def read_response(
+ self,
+ disable_decoding: bool = False,
+ timeout: Optional[float] = None,
+ ):
"""Read the response from a previously sent command"""
+ read_timeout = timeout if timeout is not None else self.socket_timeout
try:
- if self.socket_timeout:
- async with async_timeout.timeout(self.socket_timeout):
+ if read_timeout is not None:
+ async with async_timeout.timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
@@ -935,6 +887,10 @@ class Connection:
disable_decoding=disable_decoding
)
except asyncio.TimeoutError:
+ if timeout is not None:
+ # user requested timeout, return None
+ return None
+ # it was a self.socket_timeout error.
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
diff --git a/redis/client.py b/redis/client.py
index 0662a99..75a0dac 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1637,13 +1637,13 @@ class PubSub:
if response is not None:
yield response
- def get_message(self, ignore_subscribe_messages=False, timeout=0):
+ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
"""
Get the next message if one is available, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
- number.
+ number, or None, to wait indefinitely.
"""
if not self.subscribed:
# Wait for subscription
@@ -1659,7 +1659,7 @@ class PubSub:
# so no messages are available
return None
- response = self.parse_response(block=False, timeout=timeout)
+ response = self.parse_response(block=(timeout is None), timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
return None
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
index 86584e4..6dedca9 100644
--- a/tests/test_asyncio/test_pubsub.py
+++ b/tests/test_asyncio/test_pubsub.py
@@ -29,7 +29,7 @@ def with_timeout(t):
return wrapper
-async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
+async def wait_for_message(pubsub, timeout=0.2, ignore_subscribe_messages=False):
now = asyncio.get_event_loop().time()
timeout = now + timeout
while now < timeout: