summaryrefslogtreecommitdiff
path: root/pymemcache
diff options
context:
space:
mode:
authorSimon Davy <bloodearnest@gmail.com>2020-04-28 18:53:24 +0200
committerGitHub <noreply@github.com>2020-04-28 09:53:24 -0700
commit54eb398eab4c625bf3696701f27e2a285477e1f5 (patch)
tree404898ec9980817ccd8ea35aca1ef860a5930819 /pymemcache
parentd030c0e130e452f1aefe406f96cfae98f63e13e3 (diff)
downloadpymemcache-54eb398eab4c625bf3696701f27e2a285477e1f5.tar.gz
Make PooledClient and HashClient able to use a custom client class (#280)
The user can subclass Client (e.g to add telemetry or logging), and have their custom client class used by PooledClient and HashClient to create new clients.
Diffstat (limited to 'pymemcache')
-rw-r--r--pymemcache/client/base.py31
-rw-r--r--pymemcache/client/hash.py4
-rw-r--r--pymemcache/test/test_client.py7
-rw-r--r--pymemcache/test/test_client_hash.py16
4 files changed, 40 insertions, 18 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py
index 0387da4..11a2d09 100644
--- a/pymemcache/client/base.py
+++ b/pymemcache/client/base.py
@@ -1009,6 +1009,9 @@ class PooledClient(object):
in the pool. Your serde object must therefore be thread-safe.
"""
+ #: :class:`Client` class used to create new clients
+ client_class = Client
+
def __init__(self,
server,
serde=None,
@@ -1054,20 +1057,20 @@ class PooledClient(object):
key_prefix=self.key_prefix)
def _create_client(self):
- client = Client(self.server,
- serde=self.serde,
- connect_timeout=self.connect_timeout,
- timeout=self.timeout,
- no_delay=self.no_delay,
- # We need to know when it fails *always* so that we
- # can remove/destroy it from the pool...
- ignore_exc=False,
- socket_module=self.socket_module,
- key_prefix=self.key_prefix,
- default_noreply=self.default_noreply,
- allow_unicode_keys=self.allow_unicode_keys,
- tls_context=self.tls_context)
- return client
+ return self.client_class(
+ self.server,
+ serde=self.serde,
+ connect_timeout=self.connect_timeout,
+ timeout=self.timeout,
+ no_delay=self.no_delay,
+ # We need to know when it fails *always* so that we
+ # can remove/destroy it from the pool...
+ ignore_exc=False,
+ socket_module=self.socket_module,
+ key_prefix=self.key_prefix,
+ default_noreply=self.default_noreply,
+ allow_unicode_keys=self.allow_unicode_keys,
+ tls_context=self.tls_context)
def close(self):
self.client_pool.clear()
diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py
index 25e297e..8014dd2 100644
--- a/pymemcache/client/hash.py
+++ b/pymemcache/client/hash.py
@@ -14,6 +14,8 @@ class HashClient(object):
"""
A client for communicating with a cluster of memcached servers
"""
+ #: :class:`Client` class used to create new clients
+ client_class = Client
def __init__(
self,
@@ -112,7 +114,7 @@ class HashClient(object):
**self.default_kwargs
)
else:
- client = Client((server, port), **self.default_kwargs)
+ client = self.client_class((server, port), **self.default_kwargs)
self.clients[key] = client
self.hasher.add_node(key)
diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py
index 3706922..96e6551 100644
--- a/pymemcache/test/test_client.py
+++ b/pymemcache/test/test_client.py
@@ -1213,6 +1213,13 @@ class TestPooledClient(ClientTestMixin, unittest.TestCase):
[b'__FAKE_RESPONSE__\r\n'])
self._default_noreply_true('flush_all', (), [b'__FAKE_RESPONSE__\r\n'])
+ def test_custom_client(self):
+ class MyClient(Client):
+ pass
+ client = PooledClient(('host', 11211))
+ client.client_class = MyClient
+ assert isinstance(client.client_pool.get(), MyClient)
+
class TestMockClient(ClientTestMixin, unittest.TestCase):
def make_client(self, mock_socket_values, **kwargs):
diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py
index e0edfb7..814661d 100644
--- a/pymemcache/test/test_client_hash.py
+++ b/pymemcache/test/test_client_hash.py
@@ -39,7 +39,8 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
return client
def test_setup_client_without_pooling(self):
- with mock.patch('pymemcache.client.hash.Client') as internal_client:
+ client_class = 'pymemcache.client.hash.HashClient.client_class'
+ with mock.patch(client_class) as internal_client:
client = HashClient([], timeout=999, key_prefix='foo_bar_baz')
client.add_server('127.0.0.1', '11211')
@@ -295,7 +296,7 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
for client in hash_client.clients.values():
assert client.encoding == encoding
- @mock.patch("pymemcache.client.hash.Client")
+ @mock.patch("pymemcache.client.hash.HashClient.client_class")
def test_dead_server_comes_back(self, client_patch):
client = HashClient([], dead_timeout=0, retry_attempts=0)
client.add_server("127.0.0.1", 11211)
@@ -314,7 +315,7 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
assert client.get(b"key") == "Some value"
assert ("127.0.0.1", 11211) not in client._dead_clients
- @mock.patch("pymemcache.client.hash.Client")
+ @mock.patch("pymemcache.client.hash.HashClient.client_class")
def test_failed_is_retried(self, client_patch):
client = HashClient([], retry_attempts=1, retry_timeout=0)
client.add_server("127.0.0.1", 11211)
@@ -333,4 +334,13 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
assert client_patch.call_count == 1
+ def test_custom_client(self):
+ class MyClient(Client):
+ pass
+
+ client = HashClient([])
+ client.client_class = MyClient
+ client.add_server('host', 11211)
+ assert isinstance(client.clients['host:11211'], MyClient)
+
# TODO: Test failover logic