summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwoutdenolf <wout.de_nolf@esrf.eu>2023-03-16 14:08:27 +0100
committerGitHub <noreply@github.com>2023-03-16 15:08:27 +0200
commit7d474f90453c7b90bd06c94e0250b618120a599d (patch)
tree82d1640999dc4db479cec29af8abfd780f2ac3e7
parentc87172347584301f453c601c483126e4800257b7 (diff)
downloadredis-py-7d474f90453c7b90bd06c94e0250b618120a599d.tar.gz
introduce AbstractConnection so that UnixDomainSocketConnection can call super().__init__ (#2588)
Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
-rw-r--r--redis/connection.py278
1 files changed, 120 insertions, 158 deletions
diff --git a/redis/connection.py b/redis/connection.py
index c4a9685..faea768 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -6,6 +6,7 @@ import socket
import sys
import threading
import weakref
+from abc import abstractmethod
from io import SEEK_END
from itertools import chain
from queue import Empty, Full, LifoQueue
@@ -583,20 +584,13 @@ class PythonRespSerializer:
return output
-class Connection:
- "Manages TCP communication to and from a Redis server"
+class AbstractConnection:
+ "Manages communication to and from a Redis server"
def __init__(
self,
- host="localhost",
- port=6379,
db=0,
password=None,
- socket_timeout=None,
- socket_connect_timeout=None,
- socket_keepalive=False,
- socket_keepalive_options=None,
- socket_type=0,
retry_on_timeout=False,
retry_on_error=SENTINEL,
encoding="utf-8",
@@ -627,18 +621,11 @@ class Connection:
"2. 'credential_provider'"
)
self.pid = os.getpid()
- self.host = host
- self.port = int(port)
self.db = db
self.client_name = client_name
self.credential_provider = credential_provider
self.password = password
self.username = username
- self.socket_timeout = socket_timeout
- self.socket_connect_timeout = socket_connect_timeout or socket_timeout
- self.socket_keepalive = socket_keepalive
- self.socket_keepalive_options = socket_keepalive_options or {}
- self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
@@ -671,11 +658,9 @@ class Connection:
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
return f"{self.__class__.__name__}<{repr_args}>"
+ @abstractmethod
def repr_pieces(self):
- pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
- if self.client_name:
- pieces.append(("client_name", self.client_name))
- return pieces
+ pass
def __del__(self):
try:
@@ -738,75 +723,17 @@ class Connection:
if callback:
callback(self)
+ @abstractmethod
def _connect(self):
- "Create a TCP socket connection"
- # we want to mimic what socket.create_connection does to support
- # ipv4/ipv6, but we want to set options prior to calling
- # socket.connect()
- err = None
- for res in socket.getaddrinfo(
- self.host, self.port, self.socket_type, socket.SOCK_STREAM
- ):
- family, socktype, proto, canonname, socket_address = res
- sock = None
- try:
- sock = socket.socket(family, socktype, proto)
- # TCP_NODELAY
- sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-
- # TCP_KEEPALIVE
- if self.socket_keepalive:
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
- for k, v in self.socket_keepalive_options.items():
- sock.setsockopt(socket.IPPROTO_TCP, k, v)
-
- # set the socket_connect_timeout before we connect
- sock.settimeout(self.socket_connect_timeout)
-
- # connect
- sock.connect(socket_address)
-
- # set the socket_timeout now that we're connected
- sock.settimeout(self.socket_timeout)
- return sock
-
- except OSError as _:
- err = _
- if sock is not None:
- sock.close()
-
- if err is not None:
- raise err
- raise OSError("socket.getaddrinfo returned an empty list")
+ pass
+ @abstractmethod
def _host_error(self):
- try:
- host_error = f"{self.host}:{self.port}"
- except AttributeError:
- host_error = "connection"
-
- return host_error
+ pass
+ @abstractmethod
def _error_message(self, exception):
- # args for socket.error can either be (errno, "message")
- # or just "message"
-
- host_error = self._host_error()
-
- if len(exception.args) == 1:
- try:
- return f"Error connecting to {host_error}. \
- {exception.args[0]}."
- except AttributeError:
- return f"Connection Error: {exception.args[0]}"
- else:
- try:
- return (
- f"Error {exception.args[0]} connecting to "
- f"{host_error}. {exception.args[1]}."
- )
- except AttributeError:
- return f"Connection Error: {exception.args[0]}"
+ pass
def on_connect(self):
"Initialize the connection, authenticate and select a database"
@@ -990,6 +917,101 @@ class Connection:
return output
+class Connection(AbstractConnection):
+ "Manages TCP communication to and from a Redis server"
+
+ def __init__(
+ self,
+ host="localhost",
+ port=6379,
+ socket_timeout=None,
+ socket_connect_timeout=None,
+ socket_keepalive=False,
+ socket_keepalive_options=None,
+ socket_type=0,
+ **kwargs,
+ ):
+ self.host = host
+ self.port = int(port)
+ self.socket_timeout = socket_timeout
+ self.socket_connect_timeout = socket_connect_timeout or socket_timeout
+ self.socket_keepalive = socket_keepalive
+ self.socket_keepalive_options = socket_keepalive_options or {}
+ self.socket_type = socket_type
+ super().__init__(**kwargs)
+
+ def repr_pieces(self):
+ pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
+ if self.client_name:
+ pieces.append(("client_name", self.client_name))
+ return pieces
+
+ def _connect(self):
+ "Create a TCP socket connection"
+ # we want to mimic what socket.create_connection does to support
+ # ipv4/ipv6, but we want to set options prior to calling
+ # socket.connect()
+ err = None
+ for res in socket.getaddrinfo(
+ self.host, self.port, self.socket_type, socket.SOCK_STREAM
+ ):
+ family, socktype, proto, canonname, socket_address = res
+ sock = None
+ try:
+ sock = socket.socket(family, socktype, proto)
+ # TCP_NODELAY
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+ # TCP_KEEPALIVE
+ if self.socket_keepalive:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ for k, v in self.socket_keepalive_options.items():
+ sock.setsockopt(socket.IPPROTO_TCP, k, v)
+
+ # set the socket_connect_timeout before we connect
+ sock.settimeout(self.socket_connect_timeout)
+
+ # connect
+ sock.connect(socket_address)
+
+ # set the socket_timeout now that we're connected
+ sock.settimeout(self.socket_timeout)
+ return sock
+
+ except OSError as _:
+ err = _
+ if sock is not None:
+ sock.close()
+
+ if err is not None:
+ raise err
+ raise OSError("socket.getaddrinfo returned an empty list")
+
+ def _host_error(self):
+ return f"{self.host}:{self.port}"
+
+ def _error_message(self, exception):
+ # args for socket.error can either be (errno, "message")
+ # or just "message"
+
+ host_error = self._host_error()
+
+ if len(exception.args) == 1:
+ try:
+ return f"Error connecting to {host_error}. \
+ {exception.args[0]}."
+ except AttributeError:
+ return f"Connection Error: {exception.args[0]}"
+ else:
+ try:
+ return (
+ f"Error {exception.args[0]} connecting to "
+ f"{host_error}. {exception.args[1]}."
+ )
+ except AttributeError:
+ return f"Connection Error: {exception.args[0]}"
+
+
class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
This class extends the Connection class, adding SSL functionality, and making
@@ -1035,8 +1057,6 @@ class SSLConnection(Connection):
if not ssl_available:
raise RedisError("Python wasn't built with SSL support")
- super().__init__(**kwargs)
-
self.keyfile = ssl_keyfile
self.certfile = ssl_certfile
if ssl_cert_reqs is None:
@@ -1062,6 +1082,7 @@ class SSLConnection(Connection):
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
self.ssl_ocsp_context = ssl_ocsp_context
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
+ super().__init__(**kwargs)
def _connect(self):
"Wrap the socket with SSL support"
@@ -1131,77 +1152,12 @@ class SSLConnection(Connection):
return sslsock
-class UnixDomainSocketConnection(Connection):
- def __init__(
- self,
- path="",
- db=0,
- username=None,
- password=None,
- socket_timeout=None,
- encoding="utf-8",
- encoding_errors="strict",
- decode_responses=False,
- retry_on_timeout=False,
- retry_on_error=SENTINEL,
- parser_class=DefaultParser,
- socket_read_size=65536,
- health_check_interval=0,
- client_name=None,
- retry=None,
- redis_connect_func=None,
- credential_provider: Optional[CredentialProvider] = None,
- command_packer=None,
- ):
- """
- Initialize a new UnixDomainSocketConnection.
- To specify a retry policy for specific errors, first set
- `retry_on_error` to a list of the error/s to retry on, then set
- `retry` to a valid `Retry` object.
- To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
- """
- if (username or password) and credential_provider is not None:
- raise DataError(
- "'username' and 'password' cannot be passed along with 'credential_"
- "provider'. Please provide only one of the following arguments: \n"
- "1. 'password' and (optional) 'username'\n"
- "2. 'credential_provider'"
- )
- self.pid = os.getpid()
+class UnixDomainSocketConnection(AbstractConnection):
+ "Manages UDS communication to and from a Redis server"
+
+ def __init__(self, path="", **kwargs):
self.path = path
- self.db = db
- self.client_name = client_name
- self.credential_provider = credential_provider
- self.password = password
- self.username = username
- self.socket_timeout = socket_timeout
- self.retry_on_timeout = retry_on_timeout
- if retry_on_error is SENTINEL:
- retry_on_error = []
- if retry_on_timeout:
- # Add TimeoutError to the errors list to retry on
- retry_on_error.append(TimeoutError)
- self.retry_on_error = retry_on_error
- if self.retry_on_error:
- if retry is None:
- self.retry = Retry(NoBackoff(), 1)
- else:
- # deep-copy the Retry object as it is mutable
- self.retry = copy.deepcopy(retry)
- # Update the retry's supported errors with the specified errors
- self.retry.update_supported_errors(retry_on_error)
- else:
- self.retry = Retry(NoBackoff(), 0)
- self.health_check_interval = health_check_interval
- self.next_health_check = 0
- self.redis_connect_func = redis_connect_func
- self.encoder = Encoder(encoding, encoding_errors, decode_responses)
- self._sock = None
- self._socket_read_size = socket_read_size
- self.set_parser(parser_class)
- self._connect_callbacks = []
- self._buffer_cutoff = 6000
- self._command_packer = self._construct_command_packer(command_packer)
+ super().__init__(**kwargs)
def repr_pieces(self):
pieces = [("path", self.path), ("db", self.db)]
@@ -1216,15 +1172,21 @@ class UnixDomainSocketConnection(Connection):
sock.connect(self.path)
return sock
+ def _host_error(self):
+ return self.path
+
def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"
+ host_error = self._host_error()
if len(exception.args) == 1:
- return f"Error connecting to unix socket: {self.path}. {exception.args[0]}."
+ return (
+ f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
+ )
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
- f"{self.path}. {exception.args[1]}."
+ f"{host_error}. {exception.args[1]}."
)