summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Pope <nick.pope@flightdataservices.com>2021-05-07 20:22:34 +0100
committerNick Pope <nick.pope@flightdataservices.com>2021-05-10 18:18:09 +0100
commit1b471a633325cc9e0ee836778b820d658537c480 (patch)
tree9aeba1a1911dc146a1e02f432de32c8a82dbdc44
parent9e2451a3831e216ce726bf2b12f6a68d61e5cbb6 (diff)
downloadpymemcache-1b471a633325cc9e0ee836778b820d658537c480.tar.gz
Fixed `HashClient.{get,set}_many()` with UNIX sockets.
Overlooked when UNIX socket support was added to `HashClient` in acd962b586a4fe018a00acb7e1621be373c13c3b. Fixes #314.
-rw-r--r--pymemcache/client/hash.py25
-rw-r--r--pymemcache/test/test_client_hash.py54
2 files changed, 64 insertions, 15 deletions
diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py
index 9e273ee..6ec8d1f 100644
--- a/pymemcache/client/hash.py
+++ b/pymemcache/client/hash.py
@@ -112,6 +112,11 @@ class HashClient(object):
self.encoding = encoding
self.tls_context = tls_context
+ def _make_client_key(self, server):
+ if isinstance(server, (list, tuple)) and len(server) == 2:
+ return '%s:%s' % server
+ return server
+
def add_server(self, server, port=None):
# To maintain backward compatibility, if a port is provided, assume
# that server wasn't provided as a (host, port) tuple.
@@ -120,16 +125,12 @@ class HashClient(object):
raise TypeError('Server must be a string when passing port.')
server = (server, port)
- if isinstance(server, six.string_types):
- key = server
- else:
- key = '%s:%s' % server
-
_class = PooledClient if self.use_pooling else self.client_class
client = _class(server, **self.default_kwargs)
if self.use_pooling:
client.client_class = self.client_class
+ key = self._make_client_key(server)
self.clients[key] = client
self.hasher.add_node(key)
@@ -141,11 +142,7 @@ class HashClient(object):
raise TypeError('Server must be a string when passing port.')
server = (server, port)
- if isinstance(server, six.string_types):
- key = server
- else:
- key = '%s:%s' % server
-
+ key = self._make_client_key(server)
dead_time = time.time()
self._failed_clients.pop(server)
self._dead_clients[server] = dead_time
@@ -181,8 +178,7 @@ class HashClient(object):
return
raise MemcacheError('All servers seem to be down right now')
- client = self.clients[server]
- return client
+ return self.clients[server]
def _safely_run_func(self, client, func, default_val, *args, **kwargs):
try:
@@ -383,8 +379,7 @@ class HashClient(object):
client_batches[client.server][key] = value
for server, values in client_batches.items():
- client = self.clients['%s:%s' % server]
-
+ client = self.clients[self._make_client_key(server)]
failed += self._safely_run_set_many(
client, values, *args, **kwargs
)
@@ -406,7 +401,7 @@ class HashClient(object):
client_batches[client.server].append(key)
for server, keys in client_batches.items():
- client = self.clients['%s:%s' % server]
+ client = self.clients[self._make_client_key(server)]
new_args = list(args)
new_args.insert(0, keys)
diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py
index 5dd4ec4..04b5123 100644
--- a/pymemcache/test/test_client_hash.py
+++ b/pymemcache/test/test_client_hash.py
@@ -39,6 +39,20 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
return client
+ def make_unix_client(self, sockets, *mock_socket_values, **kwargs):
+ client = HashClient([], **kwargs)
+
+ for socket_, vals in zip(sockets, mock_socket_values):
+ c = self.make_client_pool(
+ socket_,
+ vals,
+ **kwargs
+ )
+ client.clients[socket_] = c
+ client.hasher.add_node(socket_)
+
+ return client
+
def test_setup_client_without_pooling(self):
client_class = 'pymemcache.client.hash.HashClient.client_class'
with mock.patch(client_class) as internal_client:
@@ -50,6 +64,30 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
assert kwargs['timeout'] == 999
assert kwargs['key_prefix'] == 'foo_bar_baz'
+ def test_get_many_unix(self):
+ pid = os.getpid()
+ sockets = [
+ '/tmp/pymemcache.1.%d' % pid,
+ '/tmp/pymemcache.2.%d' % pid,
+ ]
+ client = self.make_unix_client(sockets, *[
+ [b'STORED\r\n', b'VALUE key3 0 6\r\nvalue2\r\nEND\r\n', ],
+ [b'STORED\r\n', b'VALUE key1 0 6\r\nvalue1\r\nEND\r\n', ],
+ ])
+
+ def get_clients(key):
+ if key == b'key3':
+ return client.clients['/tmp/pymemcache.1.%d' % pid]
+ else:
+ return client.clients['/tmp/pymemcache.2.%d' % pid]
+
+ client._get_client = get_clients
+
+ result = client.set(b'key1', b'value1', noreply=False)
+ result = client.set(b'key3', b'value2', noreply=False)
+ result = client.get_many([b'key1', b'key3'])
+ assert result == {b'key1': b'value1', b'key3': b'value2'}
+
def test_get_many_all_found(self):
client = self.make_client(*[
[b'STORED\r\n', b'VALUE key3 0 6\r\nvalue2\r\nEND\r\n', ],
@@ -284,6 +322,22 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
result = client.set_many(values, noreply=True)
assert result == []
+ def test_set_many_unix(self):
+ values = {
+ 'key1': 'value1',
+ 'key2': 'value2',
+ 'key3': 'value3'
+ }
+
+ pid = os.getpid()
+ sockets = ['/tmp/pymemcache.%d' % pid]
+ client = self.make_unix_client(sockets, *[
+ [b'STORED\r\n', b'NOT_STORED\r\n', b'STORED\r\n'],
+ ])
+
+ result = client.set_many(values, noreply=False)
+ assert result == ['key2']
+
def test_server_encoding_pooled(self):
"""
test passed encoding from hash client to pooled clients