From 7d871868199133818d821cf17d20fdb5727b6ace Mon Sep 17 00:00:00 2001 From: Matej Spiller Muys Date: Fri, 20 Jan 2023 09:36:57 +0100 Subject: Support for Gat and Gats, Support passing server key in a tuple --- README.rst | 1 + pymemcache/client/base.py | 64 +++++++++++++++++++++++++++++ pymemcache/client/hash.py | 26 ++++++++---- pymemcache/test/test_client.py | 26 ++++++++++++ pymemcache/test/test_client_hash.py | 81 +++++++++++++++++++++++++------------ pymemcache/test/test_integration.py | 31 ++++++++++++++ 6 files changed, 197 insertions(+), 32 deletions(-) diff --git a/README.rst b/README.rst index 1fcd951..08ee9c8 100644 --- a/README.rst +++ b/README.rst @@ -136,6 +136,7 @@ Credits * `Nick Pope `_ * `Hervé Beraud `_ * `Martin Jørgensen `_ +* `Matej Spiller Muys `_ We're Hiring! ============= diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 04ae052..252e860 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -688,6 +688,23 @@ class Client: key, default ) + def gat(self, key: Key, expire: int = 0, default: Optional[Any] = None) -> Any: + """ + The memcached "gat" command, but only for one key, as a convenience. + + Args: + key: str, see class docs for details. + expire: optional int, number of seconds until the item is expired + from the cache, or zero for no expiry (the default). + default: value that will be returned if the key was not found. + + Returns: + The value for the key, or default if the key wasn't found. + """ + return self._fetch_cmd( + b"gat", [key], False, key_prefix=self.key_prefix, expire=expire + ).get(key, default) + def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]: """ The memcached "get" command. @@ -727,6 +744,28 @@ class Client: key, defaults ) + def gats( + self, key: Key, expire: int = 0, default: Any = None, cas_default: Any = None + ) -> Tuple[Any, Any]: + """ + The memcached "gats" command, but only for one key, as a convenience. + + Args: + key: str, see class docs for details. + expire: optional int, number of seconds until the item is expired + from the cache, or zero for no expiry (the default). + default: value that will be returned if the key was not found. + cas_default: same behaviour as default argument. + + Returns: + A tuple of (value, cas) + or (default, cas_defaults) if the key was not found. + """ + defaults = (default, cas_default) + return self._fetch_cmd( + b"gats", [key], True, key_prefix=self.key_prefix, expire=expire + ).get(key, defaults) + def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]: """ The memcached "gets" command. @@ -1118,12 +1157,17 @@ class Client: keys: Iterable[Key], expect_cas: bool, key_prefix: bytes = b"", + expire: Optional[int] = None, ) -> Dict[Key, Any]: prefixed_keys = [self.check_key(k, key_prefix=key_prefix) for k in keys] remapped_keys = dict(zip(prefixed_keys, keys)) # It is important for all keys to be listed in their original order. cmd = name + if expire is not None: + expire_bytes = self._check_integer(expire, "expire") + cmd += b" " + expire_bytes + if prefixed_keys: cmd += b" " + b" ".join(prefixed_keys) cmd += b"\r\n" @@ -1498,6 +1542,26 @@ class PooledClient: else: raise + def gat(self, key: Key, expire: int = 0, default: Optional[Any] = None) -> Any: + with self.client_pool.get_and_release(destroy_on_fail=True) as client: + try: + return client.gat(key, expire, default) + except Exception: + if self.ignore_exc: + return default + else: + raise + + def gats(self, key: Key, expire: int = 0, default: Optional[Any] = None) -> Any: + with self.client_pool.get_and_release(destroy_on_fail=True) as client: + try: + return client.gats(key, expire, default) + except Exception: + if self.ignore_exc: + return default + else: + raise + def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index e56517f..aec4090 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -170,18 +170,24 @@ class HashClient: self._last_dead_check_time = current_time def _get_client(self, key): - check_key_helper(key, self.allow_unicode_keys, self.key_prefix) + # If key is tuple use first item as server key + if isinstance(key, tuple) and len(key) == 2: + server_key, key = key + else: + server_key = key + + check_key_helper(server_key, self.allow_unicode_keys, self.key_prefix) if self._dead_clients: self._retry_dead() - server = self.hasher.get_node(key) + server = self.hasher.get_node(server_key) # We've ran out of servers to try if server is None: if self.ignore_exc is True: - return + return None, key raise MemcacheError("All servers seem to be down right now") - return self.clients[server] + return self.clients[server], key def _safely_run_func(self, client, func, default_val, *args, **kwargs): try: @@ -311,7 +317,7 @@ class HashClient: self._failed_clients[server] = failed_metadata def _run_cmd(self, cmd, key, default_val, *args, **kwargs): - client = self._get_client(key) + client, key = self._get_client(key) if client is None: return default_val @@ -346,6 +352,12 @@ class HashClient: def get(self, key, default=None, **kwargs): return self._run_cmd("get", key, default, default=default, **kwargs) + def gat(self, key, default=None, **kwargs): + return self._run_cmd("gat", key, default, default=default, **kwargs) + + def gats(self, key, default=None, **kwargs): + return self._run_cmd("gats", key, default, default=default, **kwargs) + def incr(self, key, *args, **kwargs): return self._run_cmd("incr", key, False, *args, **kwargs) @@ -357,7 +369,7 @@ class HashClient: failed = [] for key, value in values.items(): - client = self._get_client(key) + client, key = self._get_client(key) if client is None: failed.append(key) @@ -378,7 +390,7 @@ class HashClient: end = {} for key in keys: - client = self._get_client(key) + client, key = self._get_client(key) if client is None: continue diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 22f8387..cbf7cf4 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -724,6 +724,21 @@ class TestClient(ClientTestMixin, unittest.TestCase): result = client.gets_many([b"key1", b"key2"]) assert result == {b"key1": (b"value1", b"11")} + def test_gats_not_found(self): + client = self.make_client([b"END\r\n"]) + result = client.gats(b"key") + assert result == (None, None) + + def test_gats_not_found_defaults(self): + client = self.make_client([b"END\r\n"]) + result = client.gats(b"key", default="foo", cas_default="bar") + assert result == ("foo", "bar") + + def test_gats_found(self): + client = self.make_client([b"VALUE key 0 5 10\r\nvalue\r\nEND\r\n"]) + result = client.gats(b"key") + assert result == (b"value", b"10") + def test_touch_not_found(self): client = self.make_client([b"NOT_FOUND\r\n"]) result = client.touch(b"key", noreply=False) @@ -1517,6 +1532,17 @@ class TestPrefixedClient(ClientTestMixin, unittest.TestCase): result = client.get(b"key") assert result == b"value" + def test_gat_found(self): + client = self.make_client( + [ + b"STORED\r\n", + b"VALUE xyz:key 0 5\r\nvalue\r\nEND\r\n", + ] + ) + result = client.set(b"key", b"value", noreply=False) + result = client.gat(b"key") + assert result == b"value" + def test_get_many_some_found(self): client = self.make_client( [ diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py index 4cb066f..99734dd 100644 --- a/pymemcache/test/test_client_hash.py +++ b/pymemcache/test/test_client_hash.py @@ -74,19 +74,50 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): ], ) - def get_clients(key): + def get_node(key): if key == b"key3": - return client.clients["/tmp/pymemcache.1.%d" % pid] + return "/tmp/pymemcache.1.%d" % pid else: - return client.clients["/tmp/pymemcache.2.%d" % pid] + return "/tmp/pymemcache.2.%d" % pid - client._get_client = get_clients + client.hasher.get_node = get_node result = client.set(b"key1", b"value1", noreply=False) result = client.set(b"key3", b"value2", noreply=False) result = client.get_many([b"key1", b"key3"]) assert result == {b"key1": b"value1", b"key3": b"value2"} + def test_get_many_unix_with_server_key(self): + pid = os.getpid() + sockets = [ + "/tmp/pymemcache.1.%d" % pid, + "/tmp/pymemcache.2.%d" % pid, + ] + client = self.make_unix_client( + sockets, + *[ + [ + b"STORED\r\n", + b"STORED\r\n", + b"VALUE key1 0 6\r\nvalue1\r\nVALUE key3 0 6\r\nvalue2\r\nEND\r\n", + ], + [], + ], + ) + + def get_node(key): + if key == b"server_key": + return "/tmp/pymemcache.1.%d" % pid + else: + return "/tmp/pymemcache.2.%d" % pid + + client.hasher.get_node = get_node + + result = client.set((b"server_key", b"key1"), b"value1", noreply=False) + result = client.set((b"server_key", b"key3"), b"value2", noreply=False) + result = client.get_many([(b"server_key", b"key1"), (b"server_key", b"key3")]) + assert result == {b"key1": b"value1", b"key3": b"value2"} + def test_get_many_all_found(self): client = self.make_client( *[ @@ -101,13 +132,13 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): ] ) - def get_clients(key): + def get_node(key): if key == b"key3": - return client.clients["127.0.0.1:11012"] + return "127.0.0.1:11012" else: - return client.clients["127.0.0.1:11013"] + return "127.0.0.1:11013" - client._get_client = get_clients + client.hasher.get_node = get_node result = client.set(b"key1", b"value1", noreply=False) result = client.set(b"key3", b"value2", noreply=False) @@ -127,13 +158,13 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): ] ) - def get_clients(key): + def get_node(key): if key == b"key3": - return client.clients["127.0.0.1:11012"] + return "127.0.0.1:11012" else: - return client.clients["127.0.0.1:11013"] + return "127.0.0.1:11013" - client._get_client = get_clients + client.hasher.get_node = get_node result = client.set(b"key1", b"value1", noreply=False) result = client.get_many([b"key1", b"key3"]) @@ -153,13 +184,13 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): ] ) - def get_clients(key): + def get_node(key): if key == b"key3": - return client.clients["127.0.0.1:11012"] + return "127.0.0.1:11012" else: - return client.clients["127.0.0.1:11013"] + return "127.0.0.1:11013" - client._get_client = get_clients + client.hasher.get_node = get_node with pytest.raises(MemcacheUnknownError): client.set(b"key1", b"value1", noreply=False) @@ -181,13 +212,13 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): ignore_exc=True, ) - def get_clients(key): + def get_node(key): if key == b"key3": - return client.clients["127.0.0.1:11012"] + return "127.0.0.1:11012" else: - return client.clients["127.0.0.1:11013"] + return "127.0.0.1:11013" - client._get_client = get_clients + client.hasher.get_node = get_node client.set(b"key1", b"value1", noreply=False) client.set(b"key3", b"value2", noreply=False) @@ -208,13 +239,13 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): ] ) - def get_clients(key): + def get_node(key): if key == b"key3": - return client.clients["127.0.0.1:11012"] + return "127.0.0.1:11012" else: - return client.clients["127.0.0.1:11013"] + return "127.0.0.1:11013" - client._get_client = get_clients + client.hasher.get_node = get_node assert client.set(b"key1", b"value1", noreply=False) is True assert client.set(b"key3", b"value2", noreply=False) is True @@ -258,7 +289,7 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): ) hashed_client = client._get_client("foo") - assert hashed_client is None + assert hashed_client == (None, "foo") def test_no_servers_left_raise_exception(self): from pymemcache.client.hash import HashClient diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py index 19d04f2..64a8b0a 100644 --- a/pymemcache/test/test_integration.py +++ b/pymemcache/test/test_integration.py @@ -271,6 +271,37 @@ def test_touch(client_class, host, port, socket_module, key_prefix): assert result is True +@pytest.mark.integration() +def test_gat_gats(client_class, host, port, socket_module, key_prefix): + client = client_class( + (host, port), socket_module=socket_module, key_prefix=key_prefix + ) + client.flush_all() + + direct_client = ( + client if hasattr(client, "raw_command") else list(client.clients.values())[0] + ) + + result = client.set(b"key", b"0", 10, noreply=False) + assert result is True + + ttl1 = direct_client.raw_command(b"mg " + key_prefix + b"key t").replace( + b"HD t", b"" + ) + + result = client.gat(b"key", expire=1000) + assert result == b"0" + + result, cas = client.gats(b"key", expire=1000) + assert result == b"0" + + ttl2 = direct_client.raw_command(b"mg " + key_prefix + b"key t").replace( + b"HD t", b"" + ) + + assert int(ttl1) < 950 < int(ttl2) <= 1000 + + @pytest.mark.integration() def test_misc(client_class, host, port, socket_module, key_prefix): client = Client((host, port), socket_module=socket_module, key_prefix=key_prefix) -- cgit v1.2.1