summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pymemcache/client/hash.py68
-rw-r--r--pymemcache/test/conftest.py11
-rw-r--r--pymemcache/test/test_integration.py33
3 files changed, 88 insertions, 24 deletions
diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py
index 73a24e3..c4e9adb 100644
--- a/pymemcache/client/hash.py
+++ b/pymemcache/client/hash.py
@@ -128,12 +128,8 @@ class HashClient(object):
client = self.clients[server]
return client
- def _run_cmd(self, cmd, key, *args, **kwargs):
+ def _safely_run_func(self, client, func, *args, **kwargs):
try:
- can_run = True
- client = self._get_client(key)
- func = getattr(client, cmd)
-
if client.server in self._failed_clients:
# This server is currently failing, lets check if it is in retry
# or marked as dead
@@ -146,7 +142,7 @@ class HashClient(object):
if time.time() - failed_time > self.retry_timeout:
print(failed_metadata)
print('retrying')
- result = func(key, *args, **kwargs)
+ result = func(*args, **kwargs)
# we were successful, lets remove it from the failed
# clients
self._failed_clients.pop(client.server)
@@ -158,7 +154,7 @@ class HashClient(object):
print('marking as dead')
self._remove_server(*client.server)
- result = func(key, *args, **kwargs)
+ result = func(*args, **kwargs)
return result
# Connecting to the server fail, we should enter
@@ -194,8 +190,66 @@ class HashClient(object):
failed_metadata['failed_time'] = time.time()
self._failed_clients[client.server] = failed_metadata
+ def _run_cmd(self, cmd, key, *args, **kwargs):
+ client = self._get_client(key)
+ func = getattr(client, cmd)
+ args = list(args)
+ args.insert(0, key)
+ return self._safely_run_func(client, func, *args, **kwargs)
+
def set(self, key, *args, **kwargs):
return self._run_cmd('set', key, *args, **kwargs)
def get(self, key, *args, **kwargs):
return self._run_cmd('get', key, *args, **kwargs)
+
+ def get(self, key, *args, **kwargs):
+ return self._run_cmd('get', key, *args, **kwargs)
+
+ def get_many(self, keys, *args, **kwargs):
+ client_batches = {}
+ for key in keys:
+ client = self._get_client(key)
+
+ if client.server not in client_batches:
+ client_batches[client.server] = []
+
+ client_batches[client.server].append(key)
+
+ end = {}
+
+ for server, keys in client_batches.items():
+ client = self.clients['%s:%s' % server]
+ new_args = [keys] + list(args)
+ result = self._safely_run_func(
+ client,
+ client.get_many, *new_args, **kwargs
+ )
+ end.update(result)
+
+ return end
+
+ def gets(self, key, *args, **kwargs):
+ return self._run_cmd('gets', key, *args, **kwargs)
+
+ def add(self, key, *args, **kwargs):
+ return self._run_cmd('add', key, *args, **kwargs)
+
+ def prepend(self, key, *args, **kwargs):
+ return self._run_cmd('prepend', key, *args, **kwargs)
+
+ def append(self, key, *args, **kwargs):
+ return self._run_cmd('append', key, *args, **kwargs)
+
+ def delete(self, key, *args, **kwargs):
+ return self._run_cmd('delete', key, *args, **kwargs)
+
+ def cas(self, key, *args, **kwargs):
+ return self._run_cmd('cas', key, *args, **kwargs)
+
+ def replace(self, key, *args, **kwargs):
+ return self._run_cmd('replace', key, *args, **kwargs)
+
+ def flush_all(self):
+ for _, client in self.clients.items():
+ self._safely_run_func(client, client.flush_all)
diff --git a/pymemcache/test/conftest.py b/pymemcache/test/conftest.py
index 8792cc3..5d1f9ce 100644
--- a/pymemcache/test/conftest.py
+++ b/pymemcache/test/conftest.py
@@ -33,3 +33,14 @@ def pytest_generate_tests(metafunc):
socket_modules.append(gevent_socket)
metafunc.parametrize("socket_module", socket_modules)
+
+ if 'client_class' in metafunc.fixturenames:
+ from pymemcache.client.base import PooledClient, Client
+ from pymemcache.client.hash import HashClient
+ class HashClientSingle(HashClient):
+ def __init__(self, server, *args, **kwargs):
+ super(HashClientSingle, self).__init__([server], *args, **kwargs)
+
+ metafunc.parametrize(
+ "client_class", [Client, PooledClient, HashClientSingle]
+ )
diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py
index 7ac80a9..491512a 100644
--- a/pymemcache/test/test_integration.py
+++ b/pymemcache/test/test_integration.py
@@ -24,8 +24,8 @@ from pymemcache.exceptions import (
@pytest.mark.integration()
-def test_get_set(host, port, socket_module):
- client = Client((host, port), socket_module=socket_module)
+def test_get_set(client_class, host, port, socket_module):
+ client = client_class((host, port), socket_module=socket_module)
client.flush_all()
result = client.get('key')
@@ -47,8 +47,8 @@ def test_get_set(host, port, socket_module):
@pytest.mark.integration()
-def test_add_replace(host, port, socket_module):
- client = Client((host, port), socket_module=socket_module)
+def test_add_replace(client_class, host, port, socket_module):
+ client = client_class((host, port), socket_module=socket_module)
client.flush_all()
result = client.add(b'key', b'value', noreply=False)
@@ -73,8 +73,8 @@ def test_add_replace(host, port, socket_module):
@pytest.mark.integration()
-def test_append_prepend(host, port, socket_module):
- client = Client((host, port), socket_module=socket_module)
+def test_append_prepend(client_class, host, port, socket_module):
+ client = client_class((host, port), socket_module=socket_module)
client.flush_all()
result = client.append(b'key', b'value', noreply=False)
@@ -101,10 +101,9 @@ def test_append_prepend(host, port, socket_module):
@pytest.mark.integration()
-def test_cas(host, port, socket_module):
- client = Client((host, port), socket_module=socket_module)
+def test_cas(client_class, host, port, socket_module):
+ client = client_class((host, port), socket_module=socket_module)
client.flush_all()
-
result = client.cas(b'key', b'value', b'1', noreply=False)
assert result is None
@@ -125,8 +124,8 @@ def test_cas(host, port, socket_module):
@pytest.mark.integration()
-def test_gets(host, port, socket_module):
- client = Client((host, port), socket_module=socket_module)
+def test_gets(client_class, host, port, socket_module):
+ client = client_class((host, port), socket_module=socket_module)
client.flush_all()
result = client.gets(b'key')
@@ -139,8 +138,8 @@ def test_gets(host, port, socket_module):
@pytest.mark.delete()
-def test_delete(host, port, socket_module):
- client = Client((host, port), socket_module=socket_module)
+def test_delete(client_class, host, port, socket_module):
+ client = client_class((host, port), socket_module=socket_module)
client.flush_all()
result = client.delete(b'key', noreply=False)
@@ -157,7 +156,7 @@ def test_delete(host, port, socket_module):
@pytest.mark.integration()
-def test_incr_decr(host, port, socket_module):
+def test_incr_decr(client_class, host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
@@ -185,7 +184,7 @@ def test_incr_decr(host, port, socket_module):
@pytest.mark.integration()
-def test_misc(host, port, socket_module):
+def test_misc(client_class, host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
@@ -211,8 +210,8 @@ def test_serialization_deserialization(host, port, socket_module):
@pytest.mark.integration()
-def test_errors(host, port, socket_module):
- client = Client((host, port), socket_module=socket_module)
+def test_errors(client_class, host, port, socket_module):
+ client = client_class((host, port), socket_module=socket_module)
client.flush_all()
def _key_with_ws():