diff options
author | Joe Gordon <jogo@users.noreply.github.com> | 2022-10-03 14:19:02 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-03 14:19:02 -0700 |
commit | ab8bf324d19798cb2f741b53c67f6ad823aca89f (patch) | |
tree | 1e3b87a015837c611889975d2c742a1ec9bcdfd7 /pymemcache/client/base.py | |
parent | 91787fd84fac5a39b4cdacf2f84d667817bfa401 (diff) | |
parent | 7c9557435b321331362df862034d5676fd65a8ae (diff) | |
download | pymemcache-ab8bf324d19798cb2f741b53c67f6ad823aca89f.tar.gz |
Merge pull request #426 from jogo/typing
Add more type annotations
Diffstat (limited to 'pymemcache/client/base.py')
-rw-r--r-- | pymemcache/client/base.py | 232 |
1 files changed, 166 insertions, 66 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index ef6bcac..8791d5b 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -13,10 +13,12 @@ # limitations under the License. import errno -from functools import partial import platform import socket -from typing import Any, Dict, List, Optional, Tuple, Union, Callable, Iterable +from functools import partial +from ssl import SSLContext +from types import ModuleType +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from pymemcache import pool from pymemcache.exceptions import ( @@ -72,7 +74,7 @@ def _parse_hex(value: bytes) -> int: return int(value, 8) -STAT_TYPES = { +STAT_TYPES: Dict[bytes, Callable[[bytes], Any]] = { # General stats b"version": bytes, b"rusage_user": _parse_float, @@ -275,17 +277,17 @@ class Client: serde=None, serializer=None, deserializer=None, - connect_timeout=None, + connect_timeout: Optional[float] = None, timeout: Optional[float] = None, no_delay: bool = False, ignore_exc: bool = False, - socket_module=socket, + socket_module: ModuleType = socket, socket_keepalive: Optional[KeepaliveOpts] = None, key_prefix: bytes = b"", - default_noreply=True, + default_noreply: bool = True, allow_unicode_keys: bool = False, encoding: str = "ascii", - tls_context=None, + tls_context: Optional[SSLContext] = None, ): """ Constructor. @@ -510,7 +512,7 @@ class Client: expire: int = 0, noreply: Optional[bool] = None, flags: Optional[int] = None, - ): + ) -> bool: """ The memcached "add" command. @@ -532,9 +534,21 @@ class Client: """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b"add", {key: value}, expire, noreply, flags=flags)[key] + response = self._store_cmd(b"add", {key: value}, expire, noreply, flags=flags)[ + key + ] + # For typing, can only be None, if cas command + assert response is not None + return response - def replace(self, key, value, expire=0, noreply=None, flags=None): + def replace( + self, + key: Key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ) -> bool: """ The memcached "replace" command. @@ -556,11 +570,21 @@ class Client: """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b"replace", {key: value}, expire, noreply, flags=flags)[ - key - ] + response = self._store_cmd( + b"replace", {key: value}, expire, noreply, flags=flags + )[key] + # for typing + assert response is not None + return response - def append(self, key, value, expire=0, noreply=None, flags=None): + def append( + self, + key: Key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ) -> bool: """ The memcached "append" command. @@ -579,11 +603,21 @@ class Client: """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b"append", {key: value}, expire, noreply, flags=flags)[ - key - ] + response = self._store_cmd( + b"append", {key: value}, expire, noreply, flags=flags + )[key] + # For typing + assert response is not None + return response - def prepend(self, key, value, expire=0, noreply=None, flags=None): + def prepend( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): """ The memcached "prepend" command. @@ -606,7 +640,15 @@ class Client: key ] - def cas(self, key, value, cas, expire=0, noreply=False, flags=None): + def cas( + self, + key, + value, + cas, + expire: int = 0, + noreply=False, + flags: Optional[int] = None, + ) -> Optional[bool]: """ The memcached "cas" command. @@ -631,7 +673,7 @@ class Client: b"cas", {key: value}, expire, noreply, flags=flags, cas=cas )[key] - def get(self, key: Key, default: Optional[Any] = None): + def get(self, key: Key, default: Optional[Any] = None) -> Any: """ The memcached "get" command, but only for one key, as a convenience. @@ -644,7 +686,7 @@ class Client: """ return self._fetch_cmd(b"get", [key], False).get(key, default) - def get_many(self, keys): + def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]: """ The memcached "get" command. @@ -663,7 +705,9 @@ class Client: get_multi = get_many - def gets(self, key, default=None, cas_default=None): + def gets( + self, key: Key, default: Any = None, cas_default: Any = None + ) -> Tuple[Any, Any]: """ The memcached "gets" command for one key, as a convenience. @@ -679,7 +723,7 @@ class Client: defaults = (default, cas_default) return self._fetch_cmd(b"gets", [key], True).get(key, defaults) - def gets_many(self, keys): + def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]: """ The memcached "gets" command. @@ -696,7 +740,7 @@ class Client: return self._fetch_cmd(b"gets", keys, True) - def delete(self, key: Key, noreply=None): + def delete(self, key: Key, noreply: Optional[bool] = None) -> bool: """ The memcached "delete" command. @@ -865,7 +909,7 @@ class Client: return result - def cache_memlimit(self, memlimit): + def cache_memlimit(self, memlimit) -> bool: """ The memcached "cache_memlimit" command. @@ -880,7 +924,7 @@ class Client: self._fetch_cmd(b"cache_memlimit", [memlimit], False) return True - def version(self): + def version(self) -> bytes: """ The memcached "version" command. @@ -892,10 +936,12 @@ class Client: before, _, after = results[0].partition(b" ") if before != b"VERSION": - raise MemcacheUnknownError("Received unexpected response: %s" % results[0]) + raise MemcacheUnknownError(f"Received unexpected response: {results[0]!r}") return after - def raw_command(self, command, end_tokens="\r\n"): + def raw_command( + self, command: Union[str, bytes], end_tokens: Union[str, bytes] = "\r\n" + ) -> bytes: """ Sends an arbitrary command to the server and parses the response until a specified token is encountered. @@ -916,7 +962,7 @@ class Client: ) return self._misc_cmd([b"" + command + b"\r\n"], command, False, end_tokens)[0] - def flush_all(self, delay=0, noreply=None): + def flush_all(self, delay: int = 0, noreply: Optional[bool] = None) -> bool: """ The memcached "flush_all" command. @@ -931,8 +977,8 @@ class Client: """ if noreply is None: noreply = self.default_noreply - delay = self._check_integer(delay, "delay") - cmd = b"flush_all " + delay + delay_bytes = self._check_integer(delay, "delay") + cmd = b"flush_all " + delay_bytes if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -953,7 +999,7 @@ class Client: self._misc_cmd([cmd], b"quit", True) self.close() - def shutdown(self, graceful=False): + def shutdown(self, graceful: bool = False) -> None: """ The memcached "shutdown" command. @@ -979,7 +1025,7 @@ class Client: except MemcacheUnexpectedCloseError: pass - def _raise_errors(self, line, name): + def _raise_errors(self, line: bytes, name: bytes) -> None: if line.startswith(b"ERROR"): raise MemcacheUnknownCommandError(name) @@ -1029,9 +1075,9 @@ class Client: expect_cas: bool, line: bytes, buf: bytes, - remapped_keys, + remapped_keys: Dict[bytes, Key], prefixed_keys: List[bytes], - ): + ) -> Tuple[Key, Union[Any, Tuple[Any, bytes]], bytes]: """ This function is abstracted from _fetch_cmd to support different ways of value extraction. In order to use this feature, _extract_value needs @@ -1047,19 +1093,24 @@ class Client: value = None try: + # For typing + assert self.sock is not None + buf, value = _readvalue(self.sock, buf, int(size)) except MemcacheUnexpectedCloseError: self.close() raise - key = remapped_keys[key] - value = self.serde.deserialize(key, value, int(flags)) + original_key = remapped_keys[key] + value = self.serde.deserialize(original_key, value, int(flags)) if expect_cas: - return key, (value, cas), buf + return original_key, (value, cas), buf else: - return key, value, buf + return original_key, value, buf - def _fetch_cmd(self, name: bytes, keys: Iterable[Key], expect_cas: bool): + def _fetch_cmd( + self, name: bytes, keys: Iterable[Key], expect_cas: bool + ) -> Dict[Key, Any]: prefixed_keys = [self.check_key(k) for k in keys] remapped_keys = dict(zip(prefixed_keys, keys)) @@ -1080,7 +1131,7 @@ class Client: buf = b"" line = None - result: Dict[bytes, bytes] = {} + result: Dict[Key, Any] = {} while True: try: buf, line = _readline(self.sock, buf) @@ -1241,7 +1292,7 @@ class Client: self.close() raise - def __setitem__(self, key, value): + def __setitem__(self, key: Key, value): self.set(key, value, noreply=True) def __getitem__(self, key): @@ -1330,7 +1381,7 @@ class PooledClient: key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix ) - def _create_client(self): + def _create_client(self) -> Client: return self.client_class( self.server, serde=self.serde, @@ -1348,46 +1399,88 @@ class PooledClient: tls_context=self.tls_context, ) - def close(self): + def close(self) -> None: self.client_pool.clear() disconnect_all = close - def set(self, key, value, expire=0, noreply=None, flags=None): + def set( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.set(key, value, expire=expire, noreply=noreply, flags=flags) - def set_many(self, values, expire=0, noreply=None, flags=None): + def set_many( + self, + values, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.set_many(values, expire=expire, noreply=noreply, flags=flags) set_multi = set_many - def replace(self, key, value, expire=0, noreply=None, flags=None): + def replace( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.replace( key, value, expire=expire, noreply=noreply, flags=flags ) - def append(self, key, value, expire=0, noreply=None, flags=None): + def append( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.append( key, value, expire=expire, noreply=noreply, flags=flags ) - def prepend(self, key, value, expire=0, noreply=None, flags=None): + def prepend( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.prepend( key, value, expire=expire, noreply=noreply, flags=flags ) - def cas(self, key, value, cas, expire=0, noreply=False, flags=None): + def cas( + self, + key, + value, + cas, + expire: int = 0, + noreply=False, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.cas( key, value, cas, expire=expire, noreply=noreply, flags=flags ) - def get(self, key, default=None): + def get(self, key: Key, default: Any = None) -> Any: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: return client.get(key, default) @@ -1397,7 +1490,7 @@ class PooledClient: else: raise - def get_many(self, keys): + def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: return client.get_many(keys) @@ -1409,7 +1502,7 @@ class PooledClient: get_multi = get_many - def gets(self, key): + def gets(self, key: Key) -> Tuple[Any, Any]: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: return client.gets(key) @@ -1419,7 +1512,7 @@ class PooledClient: else: raise - def gets_many(self, keys): + def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: return client.gets_many(keys) @@ -1429,29 +1522,36 @@ class PooledClient: else: raise - def delete(self, key, noreply=None): + def delete(self, key: Key, noreply: Optional[bool] = None) -> bool: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.delete(key, noreply=noreply) - def delete_many(self, keys, noreply=None): + def delete_many(self, keys: Iterable[Key], noreply: Optional[bool] = None) -> bool: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.delete_many(keys, noreply=noreply) delete_multi = delete_many - def add(self, key, value, expire=0, noreply=None, flags=None): + def add( + self, + key: Key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.add(key, value, expire=expire, noreply=noreply, flags=flags) - def incr(self, key, value, noreply=False): + def incr(self, key: Key, value, noreply=False): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.incr(key, value, noreply=noreply) - def decr(self, key, value, noreply=False): + def decr(self, key: Key, value, noreply=False): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.decr(key, value, noreply=noreply) - def touch(self, key, expire=0, noreply=None): + def touch(self, key: Key, expire: int = 0, noreply=None): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.touch(key, expire=expire, noreply=noreply) @@ -1465,22 +1565,22 @@ class PooledClient: else: raise - def version(self): + def version(self) -> bytes: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.version() - def flush_all(self, delay=0, noreply=None): + def flush_all(self, delay=0, noreply=None) -> bool: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.flush_all(delay=delay, noreply=noreply) - def quit(self): + def quit(self) -> None: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: client.quit() finally: self.client_pool.destroy(client) - def shutdown(self, graceful=False): + def shutdown(self, graceful: bool = False) -> None: with self.client_pool.get_and_release(destroy_on_fail=True) as client: client.shutdown(graceful) @@ -1488,7 +1588,7 @@ class PooledClient: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.raw_command(command, end_tokens) - def __setitem__(self, key, value): + def __setitem__(self, key: Key, value): self.set(key, value, noreply=True) def __getitem__(self, key): @@ -1552,7 +1652,7 @@ def _readline(sock: socket.socket, buf: bytes) -> Tuple[bytes, bytes]: raise MemcacheUnexpectedCloseError() -def _readvalue(sock, buf, size: int): +def _readvalue(sock: socket.socket, buf: bytes, size: int): """Read specified amount of bytes from the socket. Read size bytes, followed by the "\r\n" characters, from the socket, |