summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter van Dijk <peter.van.dijk@powerdns.com>2016-11-15 00:40:07 +0100
committerAndy McCurdy <andy@andymccurdy.com>2019-12-29 14:05:24 -0800
commitdca7bd40a3a5d0c0853fe2befe706e214407697b (patch)
tree1733abbf899beda9a4cf96ba7c294fede28d6f53
parenta41465e17df3448656da45141b7b14de9d1434eb (diff)
downloadredis-py-dca7bd40a3a5d0c0853fe2befe706e214407697b.tar.gz
Allow setting client_name during connection construction.
Client instances and Connection pools now accept "client_name" as an optional argument. If supplied, all connections created will be named via CLIENT SETNAME once the connection to the server is established.
-rw-r--r--CHANGES4
-rwxr-xr-xredis/client.py3
-rwxr-xr-xredis/connection.py46
-rw-r--r--tests/test_connection_pool.py45
4 files changed, 75 insertions, 23 deletions
diff --git a/CHANGES b/CHANGES
index a8eb848..7cff020 100644
--- a/CHANGES
+++ b/CHANGES
@@ -10,6 +10,10 @@
pipeline instances relied on __len__ for boolean evaluation which
meant that pipelines with no commands on the stack would be considered
False. #994
+ * Client instances and Connection pools now support a 'client_name'
+ argument. If supplied, all connections created will call CLIENT SETNAME
+ as soon as the connection is opened. Thanks to @Habbie for supplying
+ the basis of this chanfge. #802
* 3.3.11
* Further fix for the SSLError -> TimeoutError mapping to work
on obscure releases of Python 2.7.
diff --git a/redis/client.py b/redis/client.py
index 0486022..eb1ccf1 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -684,7 +684,7 @@ class Redis(object):
ssl=False, ssl_keyfile=None, ssl_certfile=None,
ssl_cert_reqs='required', ssl_ca_certs=None,
max_connections=None, single_connection_client=False,
- health_check_interval=0):
+ health_check_interval=0, client_name=None):
if not connection_pool:
if charset is not None:
warnings.warn(DeprecationWarning(
@@ -706,6 +706,7 @@ class Redis(object):
'retry_on_timeout': retry_on_timeout,
'max_connections': max_connections,
'health_check_interval': health_check_interval,
+ 'client_name': client_name
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
diff --git a/redis/connection.py b/redis/connection.py
index b90cafe..9a0e12d 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -485,7 +485,6 @@ else:
class Connection(object):
"Manages TCP communication to and from a Redis server"
- description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"
def __init__(self, host='localhost', port=6379, db=0, username=None,
password=None, socket_timeout=None,
@@ -494,12 +493,13 @@ class Connection(object):
retry_on_timeout=False, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser, socket_read_size=65536,
- health_check_interval=0):
+ health_check_interval=0, client_name=None):
self.pid = os.getpid()
self.host = host
self.port = int(port)
self.db = db
self.username = username
+ self.client_name = client_name
self.password = password
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
@@ -512,16 +512,22 @@ class Connection(object):
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
- self._description_args = {
- 'host': self.host,
- 'port': self.port,
- 'db': self.db,
- }
self._connect_callbacks = []
self._buffer_cutoff = 6000
def __repr__(self):
- return self.description_format % self._description_args
+ repr_args = ','.join(['%s=%s' % (k, v) for k, v in self.repr_pieces()])
+ return '%s<%s>' % (self.__class__.__name__, repr_args)
+
+ def repr_pieces(self):
+ pieces = [
+ ('host', self.host),
+ ('port', self.port),
+ ('db', self.db)
+ ]
+ if self.client_name:
+ pieces.append(('client_name', self.client_name))
+ return pieces
def __del__(self):
try:
@@ -626,6 +632,12 @@ class Connection(object):
if nativestr(self.read_response()) != 'OK':
raise AuthenticationError('Invalid Username or Password')
+ # if a client_name is given, set it
+ if self.client_name:
+ self.send_command('CLIENT', 'SETNAME', self.client_name)
+ if nativestr(self.read_response()) != 'OK':
+ raise ConnectionError('Error setting client name')
+
# if a database is specified, switch to it
if self.db:
self.send_command('SELECT', self.db)
@@ -785,7 +797,6 @@ class Connection(object):
class SSLConnection(Connection):
- description_format = "SSLConnection<host=%(host)s,port=%(port)s,db=%(db)s>"
def __init__(self, ssl_keyfile=None, ssl_certfile=None,
ssl_cert_reqs='required', ssl_ca_certs=None, **kwargs):
@@ -838,18 +849,18 @@ class SSLConnection(Connection):
class UnixDomainSocketConnection(Connection):
- description_format = "UnixDomainSocketConnection<path=%(path)s,db=%(db)s>"
def __init__(self, path='', db=0, username=None, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
retry_on_timeout=False,
parser_class=DefaultParser, socket_read_size=65536,
- health_check_interval=0):
+ health_check_interval=0, client_name=None):
self.pid = os.getpid()
self.path = path
self.db = db
self.username = username
+ self.client_name = client_name
self.password = password
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
@@ -858,13 +869,18 @@ class UnixDomainSocketConnection(Connection):
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
- self._description_args = {
- 'path': self.path,
- 'db': self.db,
- }
self._connect_callbacks = []
self._buffer_cutoff = 6000
+ def repr_pieces(self):
+ pieces = [
+ ('path', self.path),
+ ('db', self.db),
+ ]
+ if self.client_name:
+ pieces.append(('client_name', self.client_name))
+ return pieces
+
def _connect(self):
"Create a Unix domain socket connection"
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index e0f0822..7ebd5ff 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -64,17 +64,28 @@ class TestConnectionPool(object):
assert c1 == c2
def test_repr_contains_db_info_tcp(self):
- connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
+ connection_kwargs = {
+ 'host': 'localhost',
+ 'port': 6379,
+ 'db': 1,
+ 'client_name': 'test-client'
+ }
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.Connection)
- expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=1>>'
+ expected = ('ConnectionPool<Connection<'
+ 'host=localhost,port=6379,db=1,client_name=test-client>>')
assert repr(pool) == expected
def test_repr_contains_db_info_unix(self):
- connection_kwargs = {'path': '/abc', 'db': 1}
+ connection_kwargs = {
+ 'path': '/abc',
+ 'db': 1,
+ 'client_name': 'test-client'
+ }
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.UnixDomainSocketConnection)
- expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
+ expected = ('ConnectionPool<UnixDomainSocketConnection<'
+ 'path=/abc,db=1,client_name=test-client>>')
assert repr(pool) == expected
def test_pool_equality(self):
@@ -177,8 +188,14 @@ class TestBlockingConnectionPool(object):
assert c1 == c2
def test_repr_contains_db_info_tcp(self):
- pool = redis.ConnectionPool(host='localhost', port=6379, db=0)
- expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=0>>'
+ pool = redis.ConnectionPool(
+ host='localhost',
+ port=6379,
+ db=0,
+ client_name='test-client'
+ )
+ expected = ('ConnectionPool<Connection<'
+ 'host=localhost,port=6379,db=0,client_name=test-client>>')
assert repr(pool) == expected
def test_repr_contains_db_info_unix(self):
@@ -186,8 +203,10 @@ class TestBlockingConnectionPool(object):
connection_class=redis.UnixDomainSocketConnection,
path='abc',
db=0,
+ client_name='test-client'
)
- expected = 'ConnectionPool<UnixDomainSocketConnection<path=abc,db=0>>'
+ expected = ('ConnectionPool<UnixDomainSocketConnection<'
+ 'path=abc,db=0,client_name=test-client>>')
assert repr(pool) == expected
@@ -364,6 +383,12 @@ class TestConnectionPoolURLParsing(object):
):
assert expected is to_bool(value)
+ def test_client_name_in_querystring(self):
+ pool = redis.ConnectionPool.from_url(
+ 'redis://location?client_name=test-client'
+ )
+ assert pool.connection_kwargs['client_name'] == 'test-client'
+
def test_invalid_extra_typed_querystring_options(self):
import warnings
with warnings.catch_warnings(record=True) as warning_log:
@@ -502,6 +527,12 @@ class TestConnectionPoolUnixSocketURLParsing(object):
'password': None,
}
+ def test_client_name_in_querystring(self):
+ pool = redis.ConnectionPool.from_url(
+ 'redis://location?client_name=test-client'
+ )
+ assert pool.connection_kwargs['client_name'] == 'test-client'
+
def test_extra_querystring_options(self):
pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2')
assert pool.connection_class == redis.UnixDomainSocketConnection