diff options
author | Nicholas Charriere <nicholascharriere@gmail.com> | 2016-08-03 10:55:46 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-03 10:55:46 -0700 |
commit | af1e26f6baa1dea0ea470ee4466723e62526150b (patch) | |
tree | 65fa18bb7fe74dce7bc7c100fcebc906293bb7e5 | |
parent | 371e673cf733caa2f42ee10173ac51770a6aa9b2 (diff) | |
parent | 7bea08f3bbf67455f58cf9bbed1e4437fbb45859 (diff) | |
download | pymemcache-af1e26f6baa1dea0ea470ee4466723e62526150b.tar.gz |
Merge pull request #107 from Morreski/master
Added default parameter support for "get" and "gets" methods in Client
-rw-r--r-- | pymemcache/client/base.py | 21 | ||||
-rw-r--r-- | pymemcache/test/test_client.py | 10 | ||||
-rw-r--r-- | pymemcache/test/utils.py | 6 |
3 files changed, 26 insertions, 11 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index ade1a18..2d48abf 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -377,17 +377,18 @@ class Client(object): """ return self._store_cmd(b'cas', key, expire, noreply, value, cas) - def get(self, key): + def get(self, key, default=None): """ The memcached "get" command, but only for one key, as a convenience. Args: key: str, see class docs for details. + default: value that will be returned if the key was not found. Returns: - The value for the key, or None if the key wasn't found. + The value for the key, or default if the key wasn't found. """ - return self._fetch_cmd(b'get', [key], False).get(key, None) + return self._fetch_cmd(b'get', [key], False).get(key, default) def get_many(self, keys): """ @@ -408,17 +409,21 @@ class Client(object): get_multi = get_many - def gets(self, key): + def gets(self, key, default=None, cas_default=None): """ The memcached "gets" command for one key, as a convenience. Args: key: str, see class docs for details. + default: value that will be returned if the key was not found. + cas_default: same behaviour as default argument. Returns: - A tuple of (key, cas), or (None, None) if the key was not found. + A tuple of (key, cas) + or (default, cas_defaults) if the key was not found. """ - return self._fetch_cmd(b'gets', [key], True).get(key, (None, None)) + defaults = (default, cas_default) + return self._fetch_cmd(b'gets', [key], True).get(key, defaults) def gets_many(self, keys): """ @@ -887,10 +892,10 @@ class PooledClient(object): return client.cas(key, value, cas, expire=expire, noreply=noreply) - def get(self, key): + def get(self, key, default=None): with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: - return client.get(key) + return client.get(key, default) except Exception: if self.ignore_exc: return None diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 7a641d8..78d4f7a 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -134,6 +134,11 @@ class ClientTestMixin(object): result = client.get(b'key') assert result is None + def test_get_not_found_default(self): + client = self.make_client([b'END\r\n']) + result = client.get(b'key', default='foobar') + assert result is 'foobar' + def test_get_found(self): client = self.make_client([ b'STORED\r\n', @@ -402,6 +407,11 @@ class TestClient(ClientTestMixin, unittest.TestCase): result = client.gets(b'key') assert result == (None, None) + def test_gets_not_found_defaults(self): + client = self.make_client([b'END\r\n']) + result = client.gets(b'key', default='foo', cas_default='bar') + assert result == ('foo', 'bar') + def test_gets_found(self): client = self.make_client([b'VALUE key 0 5 10\r\nvalue\r\nEND\r\n']) result = client.gets(b'key') diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py index f7cc1d8..997a58e 100644 --- a/pymemcache/test/utils.py +++ b/pymemcache/test/utils.py @@ -40,17 +40,17 @@ class MockMemcacheClient(object): self.no_delay = no_delay self.ignore_exc = ignore_exc - def get(self, key): + def get(self, key, default=None): if isinstance(key, six.text_type): raise MemcacheIllegalInputError(key) if key not in self._contents: - return None + return default expire, value, was_serialized = self._contents[key] if expire and expire < time.time(): del self._contents[key] - return None + return default if self.deserializer: return self.deserializer(key, value, 2 if was_serialized else 1) |