From 54eb398eab4c625bf3696701f27e2a285477e1f5 Mon Sep 17 00:00:00 2001 From: Simon Davy Date: Tue, 28 Apr 2020 18:53:24 +0200 Subject: Make PooledClient and HashClient able to use a custom client class (#280) The user can subclass Client (e.g to add telemetry or logging), and have their custom client class used by PooledClient and HashClient to create new clients. --- pymemcache/client/base.py | 31 +++++++++++++++++-------------- pymemcache/client/hash.py | 4 +++- pymemcache/test/test_client.py | 7 +++++++ pymemcache/test/test_client_hash.py | 16 +++++++++++++--- 4 files changed, 40 insertions(+), 18 deletions(-) (limited to 'pymemcache') diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 0387da4..11a2d09 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -1009,6 +1009,9 @@ class PooledClient(object): in the pool. Your serde object must therefore be thread-safe. """ + #: :class:`Client` class used to create new clients + client_class = Client + def __init__(self, server, serde=None, @@ -1054,20 +1057,20 @@ class PooledClient(object): key_prefix=self.key_prefix) def _create_client(self): - client = Client(self.server, - serde=self.serde, - connect_timeout=self.connect_timeout, - timeout=self.timeout, - no_delay=self.no_delay, - # We need to know when it fails *always* so that we - # can remove/destroy it from the pool... - ignore_exc=False, - socket_module=self.socket_module, - key_prefix=self.key_prefix, - default_noreply=self.default_noreply, - allow_unicode_keys=self.allow_unicode_keys, - tls_context=self.tls_context) - return client + return self.client_class( + self.server, + serde=self.serde, + connect_timeout=self.connect_timeout, + timeout=self.timeout, + no_delay=self.no_delay, + # We need to know when it fails *always* so that we + # can remove/destroy it from the pool... + ignore_exc=False, + socket_module=self.socket_module, + key_prefix=self.key_prefix, + default_noreply=self.default_noreply, + allow_unicode_keys=self.allow_unicode_keys, + tls_context=self.tls_context) def close(self): self.client_pool.clear() diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index 25e297e..8014dd2 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -14,6 +14,8 @@ class HashClient(object): """ A client for communicating with a cluster of memcached servers """ + #: :class:`Client` class used to create new clients + client_class = Client def __init__( self, @@ -112,7 +114,7 @@ class HashClient(object): **self.default_kwargs ) else: - client = Client((server, port), **self.default_kwargs) + client = self.client_class((server, port), **self.default_kwargs) self.clients[key] = client self.hasher.add_node(key) diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 3706922..96e6551 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -1213,6 +1213,13 @@ class TestPooledClient(ClientTestMixin, unittest.TestCase): [b'__FAKE_RESPONSE__\r\n']) self._default_noreply_true('flush_all', (), [b'__FAKE_RESPONSE__\r\n']) + def test_custom_client(self): + class MyClient(Client): + pass + client = PooledClient(('host', 11211)) + client.client_class = MyClient + assert isinstance(client.client_pool.get(), MyClient) + class TestMockClient(ClientTestMixin, unittest.TestCase): def make_client(self, mock_socket_values, **kwargs): diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py index e0edfb7..814661d 100644 --- a/pymemcache/test/test_client_hash.py +++ b/pymemcache/test/test_client_hash.py @@ -39,7 +39,8 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): return client def test_setup_client_without_pooling(self): - with mock.patch('pymemcache.client.hash.Client') as internal_client: + client_class = 'pymemcache.client.hash.HashClient.client_class' + with mock.patch(client_class) as internal_client: client = HashClient([], timeout=999, key_prefix='foo_bar_baz') client.add_server('127.0.0.1', '11211') @@ -295,7 +296,7 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): for client in hash_client.clients.values(): assert client.encoding == encoding - @mock.patch("pymemcache.client.hash.Client") + @mock.patch("pymemcache.client.hash.HashClient.client_class") def test_dead_server_comes_back(self, client_patch): client = HashClient([], dead_timeout=0, retry_attempts=0) client.add_server("127.0.0.1", 11211) @@ -314,7 +315,7 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): assert client.get(b"key") == "Some value" assert ("127.0.0.1", 11211) not in client._dead_clients - @mock.patch("pymemcache.client.hash.Client") + @mock.patch("pymemcache.client.hash.HashClient.client_class") def test_failed_is_retried(self, client_patch): client = HashClient([], retry_attempts=1, retry_timeout=0) client.add_server("127.0.0.1", 11211) @@ -333,4 +334,13 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): assert client_patch.call_count == 1 + def test_custom_client(self): + class MyClient(Client): + pass + + client = HashClient([]) + client.client_class = MyClient + client.add_server('host', 11211) + assert isinstance(client.clients['host:11211'], MyClient) + # TODO: Test failover logic -- cgit v1.2.1