summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Gordon <jogo@users.noreply.github.com>2022-10-03 14:19:02 -0700
committerGitHub <noreply@github.com>2022-10-03 14:19:02 -0700
commitab8bf324d19798cb2f741b53c67f6ad823aca89f (patch)
tree1e3b87a015837c611889975d2c742a1ec9bcdfd7
parent91787fd84fac5a39b4cdacf2f84d667817bfa401 (diff)
parent7c9557435b321331362df862034d5676fd65a8ae (diff)
downloadpymemcache-ab8bf324d19798cb2f741b53c67f6ad823aca89f.tar.gz
Merge pull request #426 from jogo/typing
Add more type annotations
-rw-r--r--pymemcache/client/base.py232
-rw-r--r--pymemcache/client/hash.py10
-rw-r--r--pymemcache/pool.py39
3 files changed, 193 insertions, 88 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py
index ef6bcac..8791d5b 100644
--- a/pymemcache/client/base.py
+++ b/pymemcache/client/base.py
@@ -13,10 +13,12 @@
# limitations under the License.
import errno
-from functools import partial
import platform
import socket
-from typing import Any, Dict, List, Optional, Tuple, Union, Callable, Iterable
+from functools import partial
+from ssl import SSLContext
+from types import ModuleType
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from pymemcache import pool
from pymemcache.exceptions import (
@@ -72,7 +74,7 @@ def _parse_hex(value: bytes) -> int:
return int(value, 8)
-STAT_TYPES = {
+STAT_TYPES: Dict[bytes, Callable[[bytes], Any]] = {
# General stats
b"version": bytes,
b"rusage_user": _parse_float,
@@ -275,17 +277,17 @@ class Client:
serde=None,
serializer=None,
deserializer=None,
- connect_timeout=None,
+ connect_timeout: Optional[float] = None,
timeout: Optional[float] = None,
no_delay: bool = False,
ignore_exc: bool = False,
- socket_module=socket,
+ socket_module: ModuleType = socket,
socket_keepalive: Optional[KeepaliveOpts] = None,
key_prefix: bytes = b"",
- default_noreply=True,
+ default_noreply: bool = True,
allow_unicode_keys: bool = False,
encoding: str = "ascii",
- tls_context=None,
+ tls_context: Optional[SSLContext] = None,
):
"""
Constructor.
@@ -510,7 +512,7 @@ class Client:
expire: int = 0,
noreply: Optional[bool] = None,
flags: Optional[int] = None,
- ):
+ ) -> bool:
"""
The memcached "add" command.
@@ -532,9 +534,21 @@ class Client:
"""
if noreply is None:
noreply = self.default_noreply
- return self._store_cmd(b"add", {key: value}, expire, noreply, flags=flags)[key]
+ response = self._store_cmd(b"add", {key: value}, expire, noreply, flags=flags)[
+ key
+ ]
+ # For typing, can only be None, if cas command
+ assert response is not None
+ return response
- def replace(self, key, value, expire=0, noreply=None, flags=None):
+ def replace(
+ self,
+ key: Key,
+ value,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ) -> bool:
"""
The memcached "replace" command.
@@ -556,11 +570,21 @@ class Client:
"""
if noreply is None:
noreply = self.default_noreply
- return self._store_cmd(b"replace", {key: value}, expire, noreply, flags=flags)[
- key
- ]
+ response = self._store_cmd(
+ b"replace", {key: value}, expire, noreply, flags=flags
+ )[key]
+ # for typing
+ assert response is not None
+ return response
- def append(self, key, value, expire=0, noreply=None, flags=None):
+ def append(
+ self,
+ key: Key,
+ value,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ) -> bool:
"""
The memcached "append" command.
@@ -579,11 +603,21 @@ class Client:
"""
if noreply is None:
noreply = self.default_noreply
- return self._store_cmd(b"append", {key: value}, expire, noreply, flags=flags)[
- key
- ]
+ response = self._store_cmd(
+ b"append", {key: value}, expire, noreply, flags=flags
+ )[key]
+ # For typing
+ assert response is not None
+ return response
- def prepend(self, key, value, expire=0, noreply=None, flags=None):
+ def prepend(
+ self,
+ key,
+ value,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ):
"""
The memcached "prepend" command.
@@ -606,7 +640,15 @@ class Client:
key
]
- def cas(self, key, value, cas, expire=0, noreply=False, flags=None):
+ def cas(
+ self,
+ key,
+ value,
+ cas,
+ expire: int = 0,
+ noreply=False,
+ flags: Optional[int] = None,
+ ) -> Optional[bool]:
"""
The memcached "cas" command.
@@ -631,7 +673,7 @@ class Client:
b"cas", {key: value}, expire, noreply, flags=flags, cas=cas
)[key]
- def get(self, key: Key, default: Optional[Any] = None):
+ def get(self, key: Key, default: Optional[Any] = None) -> Any:
"""
The memcached "get" command, but only for one key, as a convenience.
@@ -644,7 +686,7 @@ class Client:
"""
return self._fetch_cmd(b"get", [key], False).get(key, default)
- def get_many(self, keys):
+ def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]:
"""
The memcached "get" command.
@@ -663,7 +705,9 @@ class Client:
get_multi = get_many
- def gets(self, key, default=None, cas_default=None):
+ def gets(
+ self, key: Key, default: Any = None, cas_default: Any = None
+ ) -> Tuple[Any, Any]:
"""
The memcached "gets" command for one key, as a convenience.
@@ -679,7 +723,7 @@ class Client:
defaults = (default, cas_default)
return self._fetch_cmd(b"gets", [key], True).get(key, defaults)
- def gets_many(self, keys):
+ def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]:
"""
The memcached "gets" command.
@@ -696,7 +740,7 @@ class Client:
return self._fetch_cmd(b"gets", keys, True)
- def delete(self, key: Key, noreply=None):
+ def delete(self, key: Key, noreply: Optional[bool] = None) -> bool:
"""
The memcached "delete" command.
@@ -865,7 +909,7 @@ class Client:
return result
- def cache_memlimit(self, memlimit):
+ def cache_memlimit(self, memlimit) -> bool:
"""
The memcached "cache_memlimit" command.
@@ -880,7 +924,7 @@ class Client:
self._fetch_cmd(b"cache_memlimit", [memlimit], False)
return True
- def version(self):
+ def version(self) -> bytes:
"""
The memcached "version" command.
@@ -892,10 +936,12 @@ class Client:
before, _, after = results[0].partition(b" ")
if before != b"VERSION":
- raise MemcacheUnknownError("Received unexpected response: %s" % results[0])
+ raise MemcacheUnknownError(f"Received unexpected response: {results[0]!r}")
return after
- def raw_command(self, command, end_tokens="\r\n"):
+ def raw_command(
+ self, command: Union[str, bytes], end_tokens: Union[str, bytes] = "\r\n"
+ ) -> bytes:
"""
Sends an arbitrary command to the server and parses the response until a
specified token is encountered.
@@ -916,7 +962,7 @@ class Client:
)
return self._misc_cmd([b"" + command + b"\r\n"], command, False, end_tokens)[0]
- def flush_all(self, delay=0, noreply=None):
+ def flush_all(self, delay: int = 0, noreply: Optional[bool] = None) -> bool:
"""
The memcached "flush_all" command.
@@ -931,8 +977,8 @@ class Client:
"""
if noreply is None:
noreply = self.default_noreply
- delay = self._check_integer(delay, "delay")
- cmd = b"flush_all " + delay
+ delay_bytes = self._check_integer(delay, "delay")
+ cmd = b"flush_all " + delay_bytes
if noreply:
cmd += b" noreply"
cmd += b"\r\n"
@@ -953,7 +999,7 @@ class Client:
self._misc_cmd([cmd], b"quit", True)
self.close()
- def shutdown(self, graceful=False):
+ def shutdown(self, graceful: bool = False) -> None:
"""
The memcached "shutdown" command.
@@ -979,7 +1025,7 @@ class Client:
except MemcacheUnexpectedCloseError:
pass
- def _raise_errors(self, line, name):
+ def _raise_errors(self, line: bytes, name: bytes) -> None:
if line.startswith(b"ERROR"):
raise MemcacheUnknownCommandError(name)
@@ -1029,9 +1075,9 @@ class Client:
expect_cas: bool,
line: bytes,
buf: bytes,
- remapped_keys,
+ remapped_keys: Dict[bytes, Key],
prefixed_keys: List[bytes],
- ):
+ ) -> Tuple[Key, Union[Any, Tuple[Any, bytes]], bytes]:
"""
This function is abstracted from _fetch_cmd to support different ways
of value extraction. In order to use this feature, _extract_value needs
@@ -1047,19 +1093,24 @@ class Client:
value = None
try:
+ # For typing
+ assert self.sock is not None
+
buf, value = _readvalue(self.sock, buf, int(size))
except MemcacheUnexpectedCloseError:
self.close()
raise
- key = remapped_keys[key]
- value = self.serde.deserialize(key, value, int(flags))
+ original_key = remapped_keys[key]
+ value = self.serde.deserialize(original_key, value, int(flags))
if expect_cas:
- return key, (value, cas), buf
+ return original_key, (value, cas), buf
else:
- return key, value, buf
+ return original_key, value, buf
- def _fetch_cmd(self, name: bytes, keys: Iterable[Key], expect_cas: bool):
+ def _fetch_cmd(
+ self, name: bytes, keys: Iterable[Key], expect_cas: bool
+ ) -> Dict[Key, Any]:
prefixed_keys = [self.check_key(k) for k in keys]
remapped_keys = dict(zip(prefixed_keys, keys))
@@ -1080,7 +1131,7 @@ class Client:
buf = b""
line = None
- result: Dict[bytes, bytes] = {}
+ result: Dict[Key, Any] = {}
while True:
try:
buf, line = _readline(self.sock, buf)
@@ -1241,7 +1292,7 @@ class Client:
self.close()
raise
- def __setitem__(self, key, value):
+ def __setitem__(self, key: Key, value):
self.set(key, value, noreply=True)
def __getitem__(self, key):
@@ -1330,7 +1381,7 @@ class PooledClient:
key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix
)
- def _create_client(self):
+ def _create_client(self) -> Client:
return self.client_class(
self.server,
serde=self.serde,
@@ -1348,46 +1399,88 @@ class PooledClient:
tls_context=self.tls_context,
)
- def close(self):
+ def close(self) -> None:
self.client_pool.clear()
disconnect_all = close
- def set(self, key, value, expire=0, noreply=None, flags=None):
+ def set(
+ self,
+ key,
+ value,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.set(key, value, expire=expire, noreply=noreply, flags=flags)
- def set_many(self, values, expire=0, noreply=None, flags=None):
+ def set_many(
+ self,
+ values,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.set_many(values, expire=expire, noreply=noreply, flags=flags)
set_multi = set_many
- def replace(self, key, value, expire=0, noreply=None, flags=None):
+ def replace(
+ self,
+ key,
+ value,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.replace(
key, value, expire=expire, noreply=noreply, flags=flags
)
- def append(self, key, value, expire=0, noreply=None, flags=None):
+ def append(
+ self,
+ key,
+ value,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.append(
key, value, expire=expire, noreply=noreply, flags=flags
)
- def prepend(self, key, value, expire=0, noreply=None, flags=None):
+ def prepend(
+ self,
+ key,
+ value,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.prepend(
key, value, expire=expire, noreply=noreply, flags=flags
)
- def cas(self, key, value, cas, expire=0, noreply=False, flags=None):
+ def cas(
+ self,
+ key,
+ value,
+ cas,
+ expire: int = 0,
+ noreply=False,
+ flags: Optional[int] = None,
+ ):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.cas(
key, value, cas, expire=expire, noreply=noreply, flags=flags
)
- def get(self, key, default=None):
+ def get(self, key: Key, default: Any = None) -> Any:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
try:
return client.get(key, default)
@@ -1397,7 +1490,7 @@ class PooledClient:
else:
raise
- def get_many(self, keys):
+ def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
try:
return client.get_many(keys)
@@ -1409,7 +1502,7 @@ class PooledClient:
get_multi = get_many
- def gets(self, key):
+ def gets(self, key: Key) -> Tuple[Any, Any]:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
try:
return client.gets(key)
@@ -1419,7 +1512,7 @@ class PooledClient:
else:
raise
- def gets_many(self, keys):
+ def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
try:
return client.gets_many(keys)
@@ -1429,29 +1522,36 @@ class PooledClient:
else:
raise
- def delete(self, key, noreply=None):
+ def delete(self, key: Key, noreply: Optional[bool] = None) -> bool:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.delete(key, noreply=noreply)
- def delete_many(self, keys, noreply=None):
+ def delete_many(self, keys: Iterable[Key], noreply: Optional[bool] = None) -> bool:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.delete_many(keys, noreply=noreply)
delete_multi = delete_many
- def add(self, key, value, expire=0, noreply=None, flags=None):
+ def add(
+ self,
+ key: Key,
+ value,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.add(key, value, expire=expire, noreply=noreply, flags=flags)
- def incr(self, key, value, noreply=False):
+ def incr(self, key: Key, value, noreply=False):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.incr(key, value, noreply=noreply)
- def decr(self, key, value, noreply=False):
+ def decr(self, key: Key, value, noreply=False):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.decr(key, value, noreply=noreply)
- def touch(self, key, expire=0, noreply=None):
+ def touch(self, key: Key, expire: int = 0, noreply=None):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.touch(key, expire=expire, noreply=noreply)
@@ -1465,22 +1565,22 @@ class PooledClient:
else:
raise
- def version(self):
+ def version(self) -> bytes:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.version()
- def flush_all(self, delay=0, noreply=None):
+ def flush_all(self, delay=0, noreply=None) -> bool:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.flush_all(delay=delay, noreply=noreply)
- def quit(self):
+ def quit(self) -> None:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
try:
client.quit()
finally:
self.client_pool.destroy(client)
- def shutdown(self, graceful=False):
+ def shutdown(self, graceful: bool = False) -> None:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
client.shutdown(graceful)
@@ -1488,7 +1588,7 @@ class PooledClient:
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.raw_command(command, end_tokens)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: Key, value):
self.set(key, value, noreply=True)
def __getitem__(self, key):
@@ -1552,7 +1652,7 @@ def _readline(sock: socket.socket, buf: bytes) -> Tuple[bytes, bytes]:
raise MemcacheUnexpectedCloseError()
-def _readvalue(sock, buf, size: int):
+def _readvalue(sock: socket.socket, buf: bytes, size: int):
"""Read specified amount of bytes from the socket.
Read size bytes, followed by the "\r\n" characters, from the socket,
diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py
index 57b07db..e56517f 100644
--- a/pymemcache/client/hash.py
+++ b/pymemcache/client/hash.py
@@ -123,7 +123,7 @@ class HashClient:
return "%s:%s" % server
return server
- def add_server(self, server, port=None):
+ def add_server(self, server, port=None) -> None:
# To maintain backward compatibility, if a port is provided, assume
# that server wasn't provided as a (host, port) tuple.
if port is not None:
@@ -140,7 +140,7 @@ class HashClient:
self.clients[key] = client
self.hasher.add_node(key)
- def remove_server(self, server, port=None):
+ def remove_server(self, server, port=None) -> None:
# To maintain backward compatibility, if a port is provided, assume
# that server wasn't provided as a (host, port) tuple.
if port is not None:
@@ -422,7 +422,7 @@ class HashClient:
def delete(self, key, *args, **kwargs):
return self._run_cmd("delete", key, False, *args, **kwargs)
- def delete_many(self, keys, *args, **kwargs):
+ def delete_many(self, keys, *args, **kwargs) -> bool:
for key in keys:
self._run_cmd("delete", key, False, *args, **kwargs)
return True
@@ -438,10 +438,10 @@ class HashClient:
def touch(self, key, *args, **kwargs):
return self._run_cmd("touch", key, False, *args, **kwargs)
- def flush_all(self, *args, **kwargs):
+ def flush_all(self, *args, **kwargs) -> None:
for client in self.clients.values():
self._safely_run_func(client, client.flush_all, False, *args, **kwargs)
- def quit(self):
+ def quit(self) -> None:
for client in self.clients.values():
self._safely_run_func(client, client.quit, False)
diff --git a/pymemcache/pool.py b/pymemcache/pool.py
index 5c100f8..382abd9 100644
--- a/pymemcache/pool.py
+++ b/pymemcache/pool.py
@@ -14,24 +14,27 @@
import collections
import contextlib
-import sys
import threading
import time
+from typing import Callable, Optional, TypeVar, Deque, List, Generic, Iterator
-class ObjectPool:
+T = TypeVar("T")
+
+
+class ObjectPool(Generic[T]):
"""A pool of objects that release/creates/destroys as needed."""
def __init__(
self,
- obj_creator,
- after_remove=None,
- max_size=None,
- idle_timeout=0,
- lock_generator=None,
+ obj_creator: Callable[[], T],
+ after_remove: Optional[Callable] = None,
+ max_size: Optional[int] = None,
+ idle_timeout: int = 0,
+ lock_generator: Optional[Callable] = None,
):
- self._used_objs = collections.deque()
- self._free_objs = collections.deque()
+ self._used_objs: Deque[T] = collections.deque()
+ self._free_objs: Deque[T] = collections.deque()
self._obj_creator = obj_creator
if lock_generator is None:
self._lock = threading.Lock()
@@ -43,7 +46,10 @@ class ObjectPool:
raise ValueError('"max_size" must be a positive integer')
self.max_size = max_size
self.idle_timeout = idle_timeout
- self._idle_clock = time.time if idle_timeout else int
+ if idle_timeout:
+ self._idle_clock = time.time
+ else:
+ self._idle_clock = float
@property
def used(self):
@@ -54,17 +60,16 @@ class ObjectPool:
return tuple(self._free_objs)
@contextlib.contextmanager
- def get_and_release(self, destroy_on_fail=False):
+ def get_and_release(self, destroy_on_fail=False) -> Iterator[T]:
obj = self.get()
try:
yield obj
except Exception:
- exc_info = sys.exc_info()
if not destroy_on_fail:
self.release(obj)
else:
self.destroy(obj)
- raise exc_info[1].with_traceback(exc_info[2])
+ raise
self.release(obj)
def get(self):
@@ -91,7 +96,7 @@ class ObjectPool:
obj._last_used = now
return obj
- def destroy(self, obj, silent=True):
+ def destroy(self, obj, silent=True) -> None:
was_dropped = False
with self._lock:
try:
@@ -103,7 +108,7 @@ class ObjectPool:
if was_dropped and self._after_remove is not None:
self._after_remove(obj)
- def release(self, obj, silent=True):
+ def release(self, obj, silent=True) -> None:
with self._lock:
try:
self._used_objs.remove(obj)
@@ -113,9 +118,9 @@ class ObjectPool:
if not silent:
raise
- def clear(self):
+ def clear(self) -> None:
if self._after_remove is not None:
- needs_destroy = []
+ needs_destroy: List[T] = []
with self._lock:
needs_destroy.extend(self._used_objs)
needs_destroy.extend(self._free_objs)