summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Charriere <nicholascharriere@gmail.com>2016-08-03 10:55:46 -0700
committerGitHub <noreply@github.com>2016-08-03 10:55:46 -0700
commitaf1e26f6baa1dea0ea470ee4466723e62526150b (patch)
tree65fa18bb7fe74dce7bc7c100fcebc906293bb7e5
parent371e673cf733caa2f42ee10173ac51770a6aa9b2 (diff)
parent7bea08f3bbf67455f58cf9bbed1e4437fbb45859 (diff)
downloadpymemcache-af1e26f6baa1dea0ea470ee4466723e62526150b.tar.gz
Merge pull request #107 from Morreski/master
Added default parameter support for "get" and "gets" methods in Client
-rw-r--r--pymemcache/client/base.py21
-rw-r--r--pymemcache/test/test_client.py10
-rw-r--r--pymemcache/test/utils.py6
3 files changed, 26 insertions, 11 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py
index ade1a18..2d48abf 100644
--- a/pymemcache/client/base.py
+++ b/pymemcache/client/base.py
@@ -377,17 +377,18 @@ class Client(object):
"""
return self._store_cmd(b'cas', key, expire, noreply, value, cas)
- def get(self, key):
+ def get(self, key, default=None):
"""
The memcached "get" command, but only for one key, as a convenience.
Args:
key: str, see class docs for details.
+ default: value that will be returned if the key was not found.
Returns:
- The value for the key, or None if the key wasn't found.
+ The value for the key, or default if the key wasn't found.
"""
- return self._fetch_cmd(b'get', [key], False).get(key, None)
+ return self._fetch_cmd(b'get', [key], False).get(key, default)
def get_many(self, keys):
"""
@@ -408,17 +409,21 @@ class Client(object):
get_multi = get_many
- def gets(self, key):
+ def gets(self, key, default=None, cas_default=None):
"""
The memcached "gets" command for one key, as a convenience.
Args:
key: str, see class docs for details.
+ default: value that will be returned if the key was not found.
+ cas_default: same behaviour as default argument.
Returns:
- A tuple of (key, cas), or (None, None) if the key was not found.
+ A tuple of (key, cas)
+ or (default, cas_defaults) if the key was not found.
"""
- return self._fetch_cmd(b'gets', [key], True).get(key, (None, None))
+ defaults = (default, cas_default)
+ return self._fetch_cmd(b'gets', [key], True).get(key, defaults)
def gets_many(self, keys):
"""
@@ -887,10 +892,10 @@ class PooledClient(object):
return client.cas(key, value, cas,
expire=expire, noreply=noreply)
- def get(self, key):
+ def get(self, key, default=None):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
try:
- return client.get(key)
+ return client.get(key, default)
except Exception:
if self.ignore_exc:
return None
diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py
index 7a641d8..78d4f7a 100644
--- a/pymemcache/test/test_client.py
+++ b/pymemcache/test/test_client.py
@@ -134,6 +134,11 @@ class ClientTestMixin(object):
result = client.get(b'key')
assert result is None
+ def test_get_not_found_default(self):
+ client = self.make_client([b'END\r\n'])
+ result = client.get(b'key', default='foobar')
+ assert result is 'foobar'
+
def test_get_found(self):
client = self.make_client([
b'STORED\r\n',
@@ -402,6 +407,11 @@ class TestClient(ClientTestMixin, unittest.TestCase):
result = client.gets(b'key')
assert result == (None, None)
+ def test_gets_not_found_defaults(self):
+ client = self.make_client([b'END\r\n'])
+ result = client.gets(b'key', default='foo', cas_default='bar')
+ assert result == ('foo', 'bar')
+
def test_gets_found(self):
client = self.make_client([b'VALUE key 0 5 10\r\nvalue\r\nEND\r\n'])
result = client.gets(b'key')
diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py
index f7cc1d8..997a58e 100644
--- a/pymemcache/test/utils.py
+++ b/pymemcache/test/utils.py
@@ -40,17 +40,17 @@ class MockMemcacheClient(object):
self.no_delay = no_delay
self.ignore_exc = ignore_exc
- def get(self, key):
+ def get(self, key, default=None):
if isinstance(key, six.text_type):
raise MemcacheIllegalInputError(key)
if key not in self._contents:
- return None
+ return default
expire, value, was_serialized = self._contents[key]
if expire and expire < time.time():
del self._contents[key]
- return None
+ return default
if self.deserializer:
return self.deserializer(key, value, 2 if was_serialized else 1)