From b0883b791f95a595fae70bcedf3ad0f73c00e258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 29 Sep 2022 13:37:55 +0000 Subject: Simplify async timeouts and allowing `timeout=None` in `PubSub.get_message()` to wait forever (#2295) * Avoid an extra "can_read" call and use timeout directly. * Remove low-level read timeouts from the Parser, now handled in the Connection * Allow pubsub.get_message(time=None) to block. * update Changes * increase test timeout for robustness * expand with statement to avoid invoking null context managers. remove nullcontext * Remove unused import --- CHANGES | 1 + redis/asyncio/client.py | 22 ++---- redis/asyncio/connection.py | 138 +++++++++++++------------------------- redis/client.py | 6 +- tests/test_asyncio/test_pubsub.py | 2 +- 5 files changed, 57 insertions(+), 112 deletions(-) diff --git a/CHANGES b/CHANGES index a5b5029..2ced3d8 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Allow `timeout=None` in `PubSub.get_message()` to wait forever * add `nowait` flag to `asyncio.Connection.disconnect()` * Update README.md links * Fix timezone handling for datetime to unixtime conversions diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index c13054b..0e40ed7 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -24,8 +24,6 @@ from typing import ( cast, ) -import async_timeout - from redis.asyncio.connection import ( Connection, ConnectionPool, @@ -759,18 +757,8 @@ class PubSub: if not conn.is_connected: await conn.connect() - if not block: - - async def read_with_timeout(): - try: - async with async_timeout.timeout(timeout): - return await conn.read_response() - except asyncio.TimeoutError: - return None - - response = await self._execute(conn, read_with_timeout) - else: - response = await self._execute(conn, conn.read_response) + read_timeout = None if block else timeout + response = await self._execute(conn, conn.read_response, timeout=read_timeout) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it @@ -882,16 +870,16 @@ class PubSub: yield response async def get_message( - self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 ): """ Get the next message if one is available, otherwise None. If timeout is specified, the system will wait for `timeout` seconds before returning. Timeout should be specified as a floating point - number. + number or None to wait indefinitely. """ - response = await self.parse_response(block=False, timeout=timeout) + response = await self.parse_response(block=(timeout is None), timeout=timeout) if response: return await self.handle_message(response, ignore_subscribe_messages) return None diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 53b41af..c8834c9 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1,7 +1,6 @@ import asyncio import copy import enum -import errno import inspect import io import os @@ -55,16 +54,6 @@ hiredis = None if HIREDIS_AVAILABLE: import hiredis -NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { - BlockingIOError: errno.EWOULDBLOCK, - ssl.SSLWantReadError: 2, - ssl.SSLWantWriteError: 2, - ssl.SSLError: 2, -} - -NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) - - SYM_STAR = b"*" SYM_DOLLAR = b"$" SYM_CRLF = b"\r\n" @@ -229,11 +218,9 @@ class SocketBuffer: self, stream_reader: asyncio.StreamReader, socket_read_size: int, - socket_timeout: Optional[float], ): self._stream: Optional[asyncio.StreamReader] = stream_reader self.socket_read_size = socket_read_size - self.socket_timeout = socket_timeout self._buffer: Optional[io.BytesIO] = io.BytesIO() # number of bytes written to the buffer from the socket self.bytes_written = 0 @@ -244,52 +231,35 @@ class SocketBuffer: def length(self): return self.bytes_written - self.bytes_read - async def _read_from_socket( - self, - length: Optional[int] = None, - timeout: Union[float, None, _Sentinel] = SENTINEL, - raise_on_timeout: bool = True, - ) -> bool: + async def _read_from_socket(self, length: Optional[int] = None) -> bool: buf = self._buffer if buf is None or self._stream is None: raise RedisError("Buffer is closed.") buf.seek(self.bytes_written) marker = 0 - timeout = timeout if timeout is not SENTINEL else self.socket_timeout - try: - while True: - async with async_timeout.timeout(timeout): - data = await self._stream.read(self.socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - self.bytes_written += data_length - marker += data_length - - if length is not None and length > marker: - continue - return True - except (socket.timeout, asyncio.TimeoutError): - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") + while True: + data = await self._stream.read(self.socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + self.bytes_written += data_length + marker += data_length + + if length is not None and length > marker: + continue + return True async def can_read_destructive(self) -> bool: - return bool(self.length) or await self._read_from_socket( - timeout=0, raise_on_timeout=False - ) + if self.length: + return True + try: + async with async_timeout.timeout(0): + return await self._read_from_socket() + except asyncio.TimeoutError: + return False async def read(self, length: int) -> bytes: length = length + 2 # make sure to read the \r\n terminator @@ -372,9 +342,7 @@ class PythonParser(BaseParser): if self._stream is None: raise RedisError("Buffer is closed.") - self._buffer = SocketBuffer( - self._stream, self._read_size, connection.socket_timeout - ) + self._buffer = SocketBuffer(self._stream, self._read_size) self.encoder = connection.encoder def on_disconnect(self): @@ -444,14 +412,13 @@ class PythonParser(BaseParser): class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" - __slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout") + __slots__ = BaseParser.__slots__ + ("_reader",) def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) self._reader: Optional[hiredis.Reader] = None - self._socket_timeout: Optional[float] = None def on_connect(self, connection: "Connection"): self._stream = connection._reader @@ -464,7 +431,6 @@ class HiredisParser(BaseParser): kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) - self._socket_timeout = connection.socket_timeout def on_disconnect(self): self._stream = None @@ -475,39 +441,20 @@ class HiredisParser(BaseParser): raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._reader.gets(): return True - return await self.read_from_socket(timeout=0, raise_on_timeout=False) - - async def read_from_socket( - self, - timeout: Union[float, None, _Sentinel] = SENTINEL, - raise_on_timeout: bool = True, - ): - timeout = self._socket_timeout if timeout is SENTINEL else timeout try: - if timeout is None: - buffer = await self._stream.read(self._read_size) - else: - async with async_timeout.timeout(timeout): - buffer = await self._stream.read(self._read_size) - if not buffer or not isinstance(buffer, bytes): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - self._reader.feed(buffer) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - except (socket.timeout, asyncio.TimeoutError): - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") from None + async with async_timeout.timeout(0): + return await self.read_from_socket() + except asyncio.TimeoutError: return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def read_from_socket(self): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True async def read_response( self, disable_decoding: bool = False @@ -922,11 +869,16 @@ class Connection: f"Error while reading from {self.host}:{self.port}: {e.args}" ) - async def read_response(self, disable_decoding: bool = False): + async def read_response( + self, + disable_decoding: bool = False, + timeout: Optional[float] = None, + ): """Read the response from a previously sent command""" + read_timeout = timeout if timeout is not None else self.socket_timeout try: - if self.socket_timeout: - async with async_timeout.timeout(self.socket_timeout): + if read_timeout is not None: + async with async_timeout.timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) @@ -935,6 +887,10 @@ class Connection: disable_decoding=disable_decoding ) except asyncio.TimeoutError: + if timeout is not None: + # user requested timeout, return None + return None + # it was a self.socket_timeout error. await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: diff --git a/redis/client.py b/redis/client.py index 0662a99..75a0dac 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1637,13 +1637,13 @@ class PubSub: if response is not None: yield response - def get_message(self, ignore_subscribe_messages=False, timeout=0): + def get_message(self, ignore_subscribe_messages=False, timeout=0.0): """ Get the next message if one is available, otherwise None. If timeout is specified, the system will wait for `timeout` seconds before returning. Timeout should be specified as a floating point - number. + number, or None, to wait indefinitely. """ if not self.subscribed: # Wait for subscription @@ -1659,7 +1659,7 @@ class PubSub: # so no messages are available return None - response = self.parse_response(block=False, timeout=timeout) + response = self.parse_response(block=(timeout is None), timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) return None diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 86584e4..6dedca9 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -29,7 +29,7 @@ def with_timeout(t): return wrapper -async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): +async def wait_for_message(pubsub, timeout=0.2, ignore_subscribe_messages=False): now = asyncio.get_event_loop().time() timeout = now + timeout while now < timeout: -- cgit v1.2.1