diff options
author | Joe Gordon <jogo@pinterest.com> | 2022-09-30 10:40:29 -0700 |
---|---|---|
committer | Joe Gordon <jogo@pinterest.com> | 2022-10-03 13:09:09 -0700 |
commit | 7c9557435b321331362df862034d5676fd65a8ae (patch) | |
tree | cf9fa2814e18db8e42745d957f78e9add625ba3e | |
parent | 4940034c01e8440bf18e4b86619631779c7624c7 (diff) | |
download | pymemcache-7c9557435b321331362df862034d5676fd65a8ae.tar.gz |
Add more type annotations
Continue to improve the type annotation coverage
-rw-r--r-- | pymemcache/client/base.py | 232 | ||||
-rw-r--r-- | pymemcache/client/hash.py | 10 | ||||
-rw-r--r-- | pymemcache/pool.py | 39 |
3 files changed, 193 insertions, 88 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index ef6bcac..8791d5b 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -13,10 +13,12 @@ # limitations under the License. import errno -from functools import partial import platform import socket -from typing import Any, Dict, List, Optional, Tuple, Union, Callable, Iterable +from functools import partial +from ssl import SSLContext +from types import ModuleType +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from pymemcache import pool from pymemcache.exceptions import ( @@ -72,7 +74,7 @@ def _parse_hex(value: bytes) -> int: return int(value, 8) -STAT_TYPES = { +STAT_TYPES: Dict[bytes, Callable[[bytes], Any]] = { # General stats b"version": bytes, b"rusage_user": _parse_float, @@ -275,17 +277,17 @@ class Client: serde=None, serializer=None, deserializer=None, - connect_timeout=None, + connect_timeout: Optional[float] = None, timeout: Optional[float] = None, no_delay: bool = False, ignore_exc: bool = False, - socket_module=socket, + socket_module: ModuleType = socket, socket_keepalive: Optional[KeepaliveOpts] = None, key_prefix: bytes = b"", - default_noreply=True, + default_noreply: bool = True, allow_unicode_keys: bool = False, encoding: str = "ascii", - tls_context=None, + tls_context: Optional[SSLContext] = None, ): """ Constructor. @@ -510,7 +512,7 @@ class Client: expire: int = 0, noreply: Optional[bool] = None, flags: Optional[int] = None, - ): + ) -> bool: """ The memcached "add" command. @@ -532,9 +534,21 @@ class Client: """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b"add", {key: value}, expire, noreply, flags=flags)[key] + response = self._store_cmd(b"add", {key: value}, expire, noreply, flags=flags)[ + key + ] + # For typing, can only be None, if cas command + assert response is not None + return response - def replace(self, key, value, expire=0, noreply=None, flags=None): + def replace( + self, + key: Key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ) -> bool: """ The memcached "replace" command. @@ -556,11 +570,21 @@ class Client: """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b"replace", {key: value}, expire, noreply, flags=flags)[ - key - ] + response = self._store_cmd( + b"replace", {key: value}, expire, noreply, flags=flags + )[key] + # for typing + assert response is not None + return response - def append(self, key, value, expire=0, noreply=None, flags=None): + def append( + self, + key: Key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ) -> bool: """ The memcached "append" command. @@ -579,11 +603,21 @@ class Client: """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b"append", {key: value}, expire, noreply, flags=flags)[ - key - ] + response = self._store_cmd( + b"append", {key: value}, expire, noreply, flags=flags + )[key] + # For typing + assert response is not None + return response - def prepend(self, key, value, expire=0, noreply=None, flags=None): + def prepend( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): """ The memcached "prepend" command. @@ -606,7 +640,15 @@ class Client: key ] - def cas(self, key, value, cas, expire=0, noreply=False, flags=None): + def cas( + self, + key, + value, + cas, + expire: int = 0, + noreply=False, + flags: Optional[int] = None, + ) -> Optional[bool]: """ The memcached "cas" command. @@ -631,7 +673,7 @@ class Client: b"cas", {key: value}, expire, noreply, flags=flags, cas=cas )[key] - def get(self, key: Key, default: Optional[Any] = None): + def get(self, key: Key, default: Optional[Any] = None) -> Any: """ The memcached "get" command, but only for one key, as a convenience. @@ -644,7 +686,7 @@ class Client: """ return self._fetch_cmd(b"get", [key], False).get(key, default) - def get_many(self, keys): + def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]: """ The memcached "get" command. @@ -663,7 +705,9 @@ class Client: get_multi = get_many - def gets(self, key, default=None, cas_default=None): + def gets( + self, key: Key, default: Any = None, cas_default: Any = None + ) -> Tuple[Any, Any]: """ The memcached "gets" command for one key, as a convenience. @@ -679,7 +723,7 @@ class Client: defaults = (default, cas_default) return self._fetch_cmd(b"gets", [key], True).get(key, defaults) - def gets_many(self, keys): + def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]: """ The memcached "gets" command. @@ -696,7 +740,7 @@ class Client: return self._fetch_cmd(b"gets", keys, True) - def delete(self, key: Key, noreply=None): + def delete(self, key: Key, noreply: Optional[bool] = None) -> bool: """ The memcached "delete" command. @@ -865,7 +909,7 @@ class Client: return result - def cache_memlimit(self, memlimit): + def cache_memlimit(self, memlimit) -> bool: """ The memcached "cache_memlimit" command. @@ -880,7 +924,7 @@ class Client: self._fetch_cmd(b"cache_memlimit", [memlimit], False) return True - def version(self): + def version(self) -> bytes: """ The memcached "version" command. @@ -892,10 +936,12 @@ class Client: before, _, after = results[0].partition(b" ") if before != b"VERSION": - raise MemcacheUnknownError("Received unexpected response: %s" % results[0]) + raise MemcacheUnknownError(f"Received unexpected response: {results[0]!r}") return after - def raw_command(self, command, end_tokens="\r\n"): + def raw_command( + self, command: Union[str, bytes], end_tokens: Union[str, bytes] = "\r\n" + ) -> bytes: """ Sends an arbitrary command to the server and parses the response until a specified token is encountered. @@ -916,7 +962,7 @@ class Client: ) return self._misc_cmd([b"" + command + b"\r\n"], command, False, end_tokens)[0] - def flush_all(self, delay=0, noreply=None): + def flush_all(self, delay: int = 0, noreply: Optional[bool] = None) -> bool: """ The memcached "flush_all" command. @@ -931,8 +977,8 @@ class Client: """ if noreply is None: noreply = self.default_noreply - delay = self._check_integer(delay, "delay") - cmd = b"flush_all " + delay + delay_bytes = self._check_integer(delay, "delay") + cmd = b"flush_all " + delay_bytes if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -953,7 +999,7 @@ class Client: self._misc_cmd([cmd], b"quit", True) self.close() - def shutdown(self, graceful=False): + def shutdown(self, graceful: bool = False) -> None: """ The memcached "shutdown" command. @@ -979,7 +1025,7 @@ class Client: except MemcacheUnexpectedCloseError: pass - def _raise_errors(self, line, name): + def _raise_errors(self, line: bytes, name: bytes) -> None: if line.startswith(b"ERROR"): raise MemcacheUnknownCommandError(name) @@ -1029,9 +1075,9 @@ class Client: expect_cas: bool, line: bytes, buf: bytes, - remapped_keys, + remapped_keys: Dict[bytes, Key], prefixed_keys: List[bytes], - ): + ) -> Tuple[Key, Union[Any, Tuple[Any, bytes]], bytes]: """ This function is abstracted from _fetch_cmd to support different ways of value extraction. In order to use this feature, _extract_value needs @@ -1047,19 +1093,24 @@ class Client: value = None try: + # For typing + assert self.sock is not None + buf, value = _readvalue(self.sock, buf, int(size)) except MemcacheUnexpectedCloseError: self.close() raise - key = remapped_keys[key] - value = self.serde.deserialize(key, value, int(flags)) + original_key = remapped_keys[key] + value = self.serde.deserialize(original_key, value, int(flags)) if expect_cas: - return key, (value, cas), buf + return original_key, (value, cas), buf else: - return key, value, buf + return original_key, value, buf - def _fetch_cmd(self, name: bytes, keys: Iterable[Key], expect_cas: bool): + def _fetch_cmd( + self, name: bytes, keys: Iterable[Key], expect_cas: bool + ) -> Dict[Key, Any]: prefixed_keys = [self.check_key(k) for k in keys] remapped_keys = dict(zip(prefixed_keys, keys)) @@ -1080,7 +1131,7 @@ class Client: buf = b"" line = None - result: Dict[bytes, bytes] = {} + result: Dict[Key, Any] = {} while True: try: buf, line = _readline(self.sock, buf) @@ -1241,7 +1292,7 @@ class Client: self.close() raise - def __setitem__(self, key, value): + def __setitem__(self, key: Key, value): self.set(key, value, noreply=True) def __getitem__(self, key): @@ -1330,7 +1381,7 @@ class PooledClient: key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix ) - def _create_client(self): + def _create_client(self) -> Client: return self.client_class( self.server, serde=self.serde, @@ -1348,46 +1399,88 @@ class PooledClient: tls_context=self.tls_context, ) - def close(self): + def close(self) -> None: self.client_pool.clear() disconnect_all = close - def set(self, key, value, expire=0, noreply=None, flags=None): + def set( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.set(key, value, expire=expire, noreply=noreply, flags=flags) - def set_many(self, values, expire=0, noreply=None, flags=None): + def set_many( + self, + values, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.set_many(values, expire=expire, noreply=noreply, flags=flags) set_multi = set_many - def replace(self, key, value, expire=0, noreply=None, flags=None): + def replace( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.replace( key, value, expire=expire, noreply=noreply, flags=flags ) - def append(self, key, value, expire=0, noreply=None, flags=None): + def append( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.append( key, value, expire=expire, noreply=noreply, flags=flags ) - def prepend(self, key, value, expire=0, noreply=None, flags=None): + def prepend( + self, + key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.prepend( key, value, expire=expire, noreply=noreply, flags=flags ) - def cas(self, key, value, cas, expire=0, noreply=False, flags=None): + def cas( + self, + key, + value, + cas, + expire: int = 0, + noreply=False, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.cas( key, value, cas, expire=expire, noreply=noreply, flags=flags ) - def get(self, key, default=None): + def get(self, key: Key, default: Any = None) -> Any: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: return client.get(key, default) @@ -1397,7 +1490,7 @@ class PooledClient: else: raise - def get_many(self, keys): + def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: return client.get_many(keys) @@ -1409,7 +1502,7 @@ class PooledClient: get_multi = get_many - def gets(self, key): + def gets(self, key: Key) -> Tuple[Any, Any]: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: return client.gets(key) @@ -1419,7 +1512,7 @@ class PooledClient: else: raise - def gets_many(self, keys): + def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: return client.gets_many(keys) @@ -1429,29 +1522,36 @@ class PooledClient: else: raise - def delete(self, key, noreply=None): + def delete(self, key: Key, noreply: Optional[bool] = None) -> bool: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.delete(key, noreply=noreply) - def delete_many(self, keys, noreply=None): + def delete_many(self, keys: Iterable[Key], noreply: Optional[bool] = None) -> bool: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.delete_many(keys, noreply=noreply) delete_multi = delete_many - def add(self, key, value, expire=0, noreply=None, flags=None): + def add( + self, + key: Key, + value, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.add(key, value, expire=expire, noreply=noreply, flags=flags) - def incr(self, key, value, noreply=False): + def incr(self, key: Key, value, noreply=False): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.incr(key, value, noreply=noreply) - def decr(self, key, value, noreply=False): + def decr(self, key: Key, value, noreply=False): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.decr(key, value, noreply=noreply) - def touch(self, key, expire=0, noreply=None): + def touch(self, key: Key, expire: int = 0, noreply=None): with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.touch(key, expire=expire, noreply=noreply) @@ -1465,22 +1565,22 @@ class PooledClient: else: raise - def version(self): + def version(self) -> bytes: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.version() - def flush_all(self, delay=0, noreply=None): + def flush_all(self, delay=0, noreply=None) -> bool: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.flush_all(delay=delay, noreply=noreply) - def quit(self): + def quit(self) -> None: with self.client_pool.get_and_release(destroy_on_fail=True) as client: try: client.quit() finally: self.client_pool.destroy(client) - def shutdown(self, graceful=False): + def shutdown(self, graceful: bool = False) -> None: with self.client_pool.get_and_release(destroy_on_fail=True) as client: client.shutdown(graceful) @@ -1488,7 +1588,7 @@ class PooledClient: with self.client_pool.get_and_release(destroy_on_fail=True) as client: return client.raw_command(command, end_tokens) - def __setitem__(self, key, value): + def __setitem__(self, key: Key, value): self.set(key, value, noreply=True) def __getitem__(self, key): @@ -1552,7 +1652,7 @@ def _readline(sock: socket.socket, buf: bytes) -> Tuple[bytes, bytes]: raise MemcacheUnexpectedCloseError() -def _readvalue(sock, buf, size: int): +def _readvalue(sock: socket.socket, buf: bytes, size: int): """Read specified amount of bytes from the socket. Read size bytes, followed by the "\r\n" characters, from the socket, diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index 57b07db..e56517f 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -123,7 +123,7 @@ class HashClient: return "%s:%s" % server return server - def add_server(self, server, port=None): + def add_server(self, server, port=None) -> None: # To maintain backward compatibility, if a port is provided, assume # that server wasn't provided as a (host, port) tuple. if port is not None: @@ -140,7 +140,7 @@ class HashClient: self.clients[key] = client self.hasher.add_node(key) - def remove_server(self, server, port=None): + def remove_server(self, server, port=None) -> None: # To maintain backward compatibility, if a port is provided, assume # that server wasn't provided as a (host, port) tuple. if port is not None: @@ -422,7 +422,7 @@ class HashClient: def delete(self, key, *args, **kwargs): return self._run_cmd("delete", key, False, *args, **kwargs) - def delete_many(self, keys, *args, **kwargs): + def delete_many(self, keys, *args, **kwargs) -> bool: for key in keys: self._run_cmd("delete", key, False, *args, **kwargs) return True @@ -438,10 +438,10 @@ class HashClient: def touch(self, key, *args, **kwargs): return self._run_cmd("touch", key, False, *args, **kwargs) - def flush_all(self, *args, **kwargs): + def flush_all(self, *args, **kwargs) -> None: for client in self.clients.values(): self._safely_run_func(client, client.flush_all, False, *args, **kwargs) - def quit(self): + def quit(self) -> None: for client in self.clients.values(): self._safely_run_func(client, client.quit, False) diff --git a/pymemcache/pool.py b/pymemcache/pool.py index 5c100f8..382abd9 100644 --- a/pymemcache/pool.py +++ b/pymemcache/pool.py @@ -14,24 +14,27 @@ import collections import contextlib -import sys import threading import time +from typing import Callable, Optional, TypeVar, Deque, List, Generic, Iterator -class ObjectPool: +T = TypeVar("T") + + +class ObjectPool(Generic[T]): """A pool of objects that release/creates/destroys as needed.""" def __init__( self, - obj_creator, - after_remove=None, - max_size=None, - idle_timeout=0, - lock_generator=None, + obj_creator: Callable[[], T], + after_remove: Optional[Callable] = None, + max_size: Optional[int] = None, + idle_timeout: int = 0, + lock_generator: Optional[Callable] = None, ): - self._used_objs = collections.deque() - self._free_objs = collections.deque() + self._used_objs: Deque[T] = collections.deque() + self._free_objs: Deque[T] = collections.deque() self._obj_creator = obj_creator if lock_generator is None: self._lock = threading.Lock() @@ -43,7 +46,10 @@ class ObjectPool: raise ValueError('"max_size" must be a positive integer') self.max_size = max_size self.idle_timeout = idle_timeout - self._idle_clock = time.time if idle_timeout else int + if idle_timeout: + self._idle_clock = time.time + else: + self._idle_clock = float @property def used(self): @@ -54,17 +60,16 @@ class ObjectPool: return tuple(self._free_objs) @contextlib.contextmanager - def get_and_release(self, destroy_on_fail=False): + def get_and_release(self, destroy_on_fail=False) -> Iterator[T]: obj = self.get() try: yield obj except Exception: - exc_info = sys.exc_info() if not destroy_on_fail: self.release(obj) else: self.destroy(obj) - raise exc_info[1].with_traceback(exc_info[2]) + raise self.release(obj) def get(self): @@ -91,7 +96,7 @@ class ObjectPool: obj._last_used = now return obj - def destroy(self, obj, silent=True): + def destroy(self, obj, silent=True) -> None: was_dropped = False with self._lock: try: @@ -103,7 +108,7 @@ class ObjectPool: if was_dropped and self._after_remove is not None: self._after_remove(obj) - def release(self, obj, silent=True): + def release(self, obj, silent=True) -> None: with self._lock: try: self._used_objs.remove(obj) @@ -113,9 +118,9 @@ class ObjectPool: if not silent: raise - def clear(self): + def clear(self) -> None: if self._after_remove is not None: - needs_destroy = [] + needs_destroy: List[T] = [] with self._lock: needs_destroy.extend(self._used_objs) needs_destroy.extend(self._free_objs) |