summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Gordon <jogo@pinterest.com>2022-08-17 10:01:18 -0700
committerJoe Gordon <jogo@pinterest.com>2022-09-12 09:49:43 -0700
commitaba7f3ffd0fe76f0a6068c05133b5bb2fef94909 (patch)
tree2d88b451d1782046b21ab62241c550b0f26c6057
parent6b85dea74b6545a1a0dc2f1f22bc84ec7c4ea8cd (diff)
downloadpymemcache-aba7f3ffd0fe76f0a6068c05133b5bb2fef94909.tar.gz
Start to add type hints
First pass at adding some type hints to pymemcache to make it easier to develop against etc.
-rw-r--r--pymemcache/client/base.py178
-rw-r--r--pymemcache/client/hash.py2
-rw-r--r--pymemcache/serde.py10
-rw-r--r--pyproject.toml1
4 files changed, 126 insertions, 65 deletions
diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py
index 3d4234c..ef6bcac 100644
--- a/pymemcache/client/base.py
+++ b/pymemcache/client/base.py
@@ -16,20 +16,18 @@ import errno
from functools import partial
import platform
import socket
-from typing import Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable, Iterable
from pymemcache import pool
-
-from pymemcache.serde import LegacyWrappingSerde
from pymemcache.exceptions import (
MemcacheClientError,
- MemcacheUnknownCommandError,
MemcacheIllegalInputError,
MemcacheServerError,
- MemcacheUnknownError,
MemcacheUnexpectedCloseError,
+ MemcacheUnknownCommandError,
+ MemcacheUnknownError,
)
-
+from pymemcache.serde import LegacyWrappingSerde
RECV_SIZE = 4096
VALID_STORE_RESULTS = {
@@ -53,23 +51,24 @@ STORE_RESULTS_VALUE = {
}
ServerSpec = Union[Tuple[str, int], str]
+Key = Union[bytes, str]
# Some of the values returned by the "stats" command
# need mapping into native Python types
-def _parse_bool_int(value):
+def _parse_bool_int(value: bytes) -> bool:
return int(value) != 0
-def _parse_bool_string_is_yes(value):
+def _parse_bool_string_is_yes(value: bytes) -> bool:
return value == b"yes"
-def _parse_float(value):
+def _parse_float(value: bytes) -> float:
return float(value.replace(b":", b"."))
-def _parse_hex(value):
+def _parse_hex(value: bytes) -> int:
return int(value, 8)
@@ -96,7 +95,9 @@ STAT_TYPES = {
# Common helper functions.
-def check_key_helper(key, allow_unicode_keys, key_prefix=b""):
+def check_key_helper(
+ key: Key, allow_unicode_keys: bool, key_prefix: bytes = b""
+) -> bytes:
"""Checks key and add key_prefix."""
if allow_unicode_keys:
if isinstance(key, str):
@@ -160,7 +161,7 @@ class KeepaliveOpts:
__slots__ = ("idle", "intvl", "cnt")
- def __init__(self, idle=1, intvl=1, cnt=5):
+ def __init__(self, idle: int = 1, intvl: int = 1, cnt: int = 5) -> None:
if idle < 1:
raise ValueError("The idle parameter must be greater or equal to 1.")
self.idle = idle
@@ -275,15 +276,15 @@ class Client:
serializer=None,
deserializer=None,
connect_timeout=None,
- timeout=None,
- no_delay=False,
- ignore_exc=False,
+ timeout: Optional[float] = None,
+ no_delay: bool = False,
+ ignore_exc: bool = False,
socket_module=socket,
- socket_keepalive=None,
- key_prefix=b"",
+ socket_keepalive: Optional[KeepaliveOpts] = None,
+ key_prefix: bytes = b"",
default_noreply=True,
- allow_unicode_keys=False,
- encoding="ascii",
+ allow_unicode_keys: bool = False,
+ encoding: str = "ascii",
tls_context=None,
):
"""
@@ -354,7 +355,7 @@ class Client:
"KeepaliveOpts object. That's the only supported type "
"of structure."
)
- self.sock = None
+ self.sock: Optional[socket.socket] = None
if isinstance(key_prefix, str):
key_prefix = key_prefix.encode("ascii")
if not isinstance(key_prefix, bytes):
@@ -365,13 +366,13 @@ class Client:
self.encoding = encoding
self.tls_context = tls_context
- def check_key(self, key):
+ def check_key(self, key: Key) -> bytes:
"""Checks key and add key_prefix."""
return check_key_helper(
key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix
)
- def _connect(self):
+ def _connect(self) -> None:
self.close()
s = self.socket_module
@@ -426,7 +427,7 @@ class Client:
self.sock = sock
- def close(self):
+ def close(self) -> None:
"""Close the connection to memcached, if it is open. The next call to a
method that requires a connection will re-open it."""
if self.sock is not None:
@@ -439,7 +440,14 @@ class Client:
disconnect_all = close
- def set(self, key, value, expire=0, noreply=None, flags=None):
+ def set(
+ self,
+ key: Key,
+ value: Any,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ) -> Optional[bool]:
"""
The memcached "set" command.
@@ -460,9 +468,17 @@ class Client:
"""
if noreply is None:
noreply = self.default_noreply
+ # Optional because _store_cmd lookup in STORE_RESULTS_VALUE can return None in some cases.
+ # TODO: refactor to fix
return self._store_cmd(b"set", {key: value}, expire, noreply, flags=flags)[key]
- def set_many(self, values, expire=0, noreply=None, flags=None):
+ def set_many(
+ self,
+ values: Dict[Key, Any],
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ) -> List[Key]:
"""
A convenience function for setting multiple values.
@@ -487,7 +503,14 @@ class Client:
set_multi = set_many
- def add(self, key, value, expire=0, noreply=None, flags=None):
+ def add(
+ self,
+ key: Key,
+ value: Any,
+ expire: int = 0,
+ noreply: Optional[bool] = None,
+ flags: Optional[int] = None,
+ ):
"""
The memcached "add" command.
@@ -608,7 +631,7 @@ class Client:
b"cas", {key: value}, expire, noreply, flags=flags, cas=cas
)[key]
- def get(self, key, default=None):
+ def get(self, key: Key, default: Optional[Any] = None):
"""
The memcached "get" command, but only for one key, as a convenience.
@@ -673,7 +696,7 @@ class Client:
return self._fetch_cmd(b"gets", keys, True)
- def delete(self, key, noreply=None):
+ def delete(self, key: Key, noreply=None):
"""
The memcached "delete" command.
@@ -698,7 +721,7 @@ class Client:
return True
return results[0] == b"DELETED"
- def delete_many(self, keys, noreply=None):
+ def delete_many(self, keys: Iterable[Key], noreply: Optional[bool] = None) -> bool:
"""
A convenience function to delete multiple keys.
@@ -732,7 +755,9 @@ class Client:
delete_multi = delete_many
- def incr(self, key, value, noreply=False):
+ def incr(
+ self, key: Key, value: int, noreply: Optional[bool] = False
+ ) -> Optional[int]:
"""
The memcached "incr" command.
@@ -746,8 +771,8 @@ class Client:
value of the key, or None if the key wasn't found.
"""
key = self.check_key(key)
- value = self._check_integer(value, "value")
- cmd = b"incr " + key + b" " + value
+ val = self._check_integer(value, "value")
+ cmd = b"incr " + key + b" " + val
if noreply:
cmd += b" noreply"
cmd += b"\r\n"
@@ -758,7 +783,9 @@ class Client:
return None
return int(results[0])
- def decr(self, key, value, noreply=False):
+ def decr(
+ self, key: Key, value: int, noreply: Optional[bool] = False
+ ) -> Optional[int]:
"""
The memcached "decr" command.
@@ -772,8 +799,8 @@ class Client:
value of the key, or None if the key wasn't found.
"""
key = self.check_key(key)
- value = self._check_integer(value, "value")
- cmd = b"decr " + key + b" " + value
+ val = self._check_integer(value, "value")
+ cmd = b"decr " + key + b" " + val
if noreply:
cmd += b" noreply"
cmd += b"\r\n"
@@ -784,7 +811,7 @@ class Client:
return None
return int(results[0])
- def touch(self, key, expire=0, noreply=None):
+ def touch(self, key: Key, expire: int = 0, noreply: Optional[bool] = None) -> bool:
"""
The memcached "touch" command.
@@ -802,8 +829,8 @@ class Client:
if noreply is None:
noreply = self.default_noreply
key = self.check_key(key)
- expire = self._check_integer(expire, "expire")
- cmd = b"touch " + key + b" " + expire
+ expire_bytes = self._check_integer(expire, "expire")
+ cmd = b"touch " + key + b" " + expire_bytes
if noreply:
cmd += b" noreply"
cmd += b"\r\n"
@@ -914,7 +941,7 @@ class Client:
return True
return results[0] == b"OK"
- def quit(self):
+ def quit(self) -> None:
"""
The memcached "quit" command.
@@ -964,7 +991,7 @@ class Client:
error = line[line.find(b" ") + 1 :]
raise MemcacheServerError(error)
- def _check_integer(self, value, name):
+ def _check_integer(self, value: int, name: str) -> bytes:
"""Check that a value is an integer and encode it as a binary string"""
if not isinstance(value, int):
raise MemcacheIllegalInputError(
@@ -973,7 +1000,7 @@ class Client:
return str(value).encode(self.encoding)
- def _check_cas(self, cas):
+ def _check_cas(self, cas: Union[int, str, bytes]) -> bytes:
"""Check that a value is a valid input for 'cas' -- either an int or a
string containing only 0-9
@@ -997,7 +1024,14 @@ class Client:
return cas
- def _extract_value(self, expect_cas, line, buf, remapped_keys, prefixed_keys):
+ def _extract_value(
+ self,
+ expect_cas: bool,
+ line: bytes,
+ buf: bytes,
+ remapped_keys,
+ prefixed_keys: List[bytes],
+ ):
"""
This function is abstracted from _fetch_cmd to support different ways
of value extraction. In order to use this feature, _extract_value needs
@@ -1009,7 +1043,7 @@ class Client:
try:
_, key, flags, size = line.split()
except Exception as e:
- raise ValueError(f"Unable to parse line {line}: {e}")
+ raise ValueError(f"Unable to parse line {line!r}: {e}")
value = None
try:
@@ -1025,7 +1059,7 @@ class Client:
else:
return key, value, buf
- def _fetch_cmd(self, name, keys, expect_cas):
+ def _fetch_cmd(self, name: bytes, keys: Iterable[Key], expect_cas: bool):
prefixed_keys = [self.check_key(k) for k in keys]
remapped_keys = dict(zip(prefixed_keys, keys))
@@ -1039,11 +1073,14 @@ class Client:
if self.sock is None:
self._connect()
+ # For typing
+ assert self.sock is not None
+
self.sock.sendall(cmd)
buf = b""
line = None
- result = {}
+ result: Dict[bytes, bytes] = {}
while True:
try:
buf, line = _readline(self.sock, buf)
@@ -1073,7 +1110,15 @@ class Client:
return {}
raise
- def _store_cmd(self, name, values, expire, noreply, flags=None, cas=None):
+ def _store_cmd(
+ self,
+ name: bytes,
+ values: Dict[Key, Any],
+ expire: int,
+ noreply: bool,
+ flags: Optional[int] = None,
+ cas: Optional[bytes] = None,
+ ) -> Dict[Key, Optional[bool]]:
cmds = []
keys = []
@@ -1082,7 +1127,7 @@ class Client:
extra += b" " + cas
if noreply:
extra += b" noreply"
- expire = self._check_integer(expire, "expire")
+ expire_bytes = self._check_integer(expire, "expire")
for key, data in values.items():
# must be able to reliably map responses back to the original order
@@ -1111,7 +1156,7 @@ class Client:
+ b" "
+ str(data_flags).encode(self.encoding)
+ b" "
- + expire
+ + expire_bytes
+ b" "
+ str(len(data)).encode(self.encoding)
+ extra
@@ -1123,6 +1168,9 @@ class Client:
if self.sock is None:
self._connect()
+ # For typing
+ assert self.sock is not None
+
try:
self.sock.sendall(b"".join(cmds))
if noreply:
@@ -1148,10 +1196,17 @@ class Client:
self.close()
raise
- def _misc_cmd(self, cmds, cmd_name, noreply, end_tokens=None):
+ def _misc_cmd(
+ self,
+ cmds: Iterable[bytes],
+ cmd_name: bytes,
+ noreply: Optional[bool],
+ end_tokens=None,
+ ) -> List[bytes]:
# If no end_tokens have been given, just assume standard memcached
# operations, which end in "\r\n", use regular code for that.
+ _reader: Callable[[socket.socket, bytes], Tuple[bytes, bytes]]
if end_tokens:
_reader = partial(_readsegment, end_tokens=end_tokens)
else:
@@ -1160,6 +1215,9 @@ class Client:
if self.sock is None:
self._connect()
+ # For typing
+ assert self.sock is not None
+
try:
self.sock.sendall(b"".join(cmds))
@@ -1236,7 +1294,7 @@ class PooledClient:
max_pool_size=None,
pool_idle_timeout=0,
lock_generator=None,
- default_noreply=True,
+ default_noreply: bool = True,
allow_unicode_keys=False,
encoding="ascii",
tls_context=None,
@@ -1266,7 +1324,7 @@ class PooledClient:
self.encoding = encoding
self.tls_context = tls_context
- def check_key(self, key):
+ def check_key(self, key: Key) -> bytes:
"""Checks key and add key_prefix."""
return check_key_helper(
key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix
@@ -1443,7 +1501,7 @@ class PooledClient:
self.delete(key, noreply=True)
-def _readline(sock, buf):
+def _readline(sock: socket.socket, buf: bytes) -> Tuple[bytes, bytes]:
"""Read line of text from the socket.
Read a line of text (delimited by "\r\n") from the socket, and
@@ -1452,18 +1510,18 @@ def _readline(sock, buf):
Args:
sock: Socket object, should be connected.
- buf: String, zero or more characters, returned from an earlier
- call to _readline or _readvalue (pass an empty string on the
+ buf: Bytes, zero or more characters, returned from an earlier
+ call to _readline or _readvalue (pass an empty byte string on the
first call).
Returns:
A tuple of (buf, line) where line is the full line read from the
socket (minus the "\r\n" characters) and buf is any trailing
characters read after the "\r\n" was found (which may be an empty
- string).
+ byte string).
"""
- chunks = []
+ chunks: List[bytes] = []
last_char = b""
while True:
@@ -1494,7 +1552,7 @@ def _readline(sock, buf):
raise MemcacheUnexpectedCloseError()
-def _readvalue(sock, buf, size):
+def _readvalue(sock, buf, size: int):
"""Read specified amount of bytes from the socket.
Read size bytes, followed by the "\r\n" characters, from the socket,
@@ -1539,7 +1597,9 @@ def _readvalue(sock, buf, size):
return buf[rlen:], b"".join(chunks)
-def _readsegment(sock, buf, end_tokens):
+def _readsegment(
+ sock: socket.socket, buf: bytes, end_tokens: bytes
+) -> Tuple[bytes, bytes]:
"""Read a segment from the socket.
Read a segment from the socket, up to the first end_token sub-string/bytes,
@@ -1575,7 +1635,7 @@ def _readsegment(sock, buf, end_tokens):
raise MemcacheUnexpectedCloseError()
-def _recv(sock, size):
+def _recv(sock: socket.socket, size: int) -> bytes:
"""sock.recv() with retry on EINTR"""
while True:
try:
diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py
index db44108..57b07db 100644
--- a/pymemcache/client/hash.py
+++ b/pymemcache/client/hash.py
@@ -154,7 +154,7 @@ class HashClient:
self._dead_clients[server] = dead_time
self.hasher.remove_node(key)
- def _retry_dead(self):
+ def _retry_dead(self) -> None:
current_time = time.time()
ldc = self._last_dead_check_time
# We have reached the retry timeout
diff --git a/pymemcache/serde.py b/pymemcache/serde.py
index 6e77766..42ec922 100644
--- a/pymemcache/serde.py
+++ b/pymemcache/serde.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from functools import partial
import logging
-from io import BytesIO
import pickle
import zlib
+from functools import partial
+from io import BytesIO
FLAG_BYTES = 0
FLAG_PICKLE = 1 << 0
@@ -61,7 +61,7 @@ def _python_memcache_serializer(key, value, pickle_version=None):
return value, flags
-def get_python_memcache_serializer(pickle_version=DEFAULT_PICKLE_VERSION):
+def get_python_memcache_serializer(pickle_version: int = DEFAULT_PICKLE_VERSION):
"""Return a serializer using a specific pickle version"""
return partial(_python_memcache_serializer, pickle_version=pickle_version)
@@ -112,7 +112,7 @@ class PickleSerde:
for :py:class:`pymemcache.client.base.Client`
"""
- def __init__(self, pickle_version=DEFAULT_PICKLE_VERSION):
+ def __init__(self, pickle_version: int = DEFAULT_PICKLE_VERSION) -> None:
self._serialize_func = get_python_memcache_serializer(pickle_version)
def serialize(self, key, value):
@@ -182,7 +182,7 @@ class LegacyWrappingSerde:
case that they are missing.
"""
- def __init__(self, serializer_func, deserializer_func):
+ def __init__(self, serializer_func, deserializer_func) -> None:
self.serialize = serializer_func or self._default_serialize
self.deserialize = deserializer_func or self._default_deserialize
diff --git a/pyproject.toml b/pyproject.toml
index 8d26def..d08806c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,3 +3,4 @@ target-version = ['py37', 'py38', 'py39', 'py310']
[tool.mypy]
python_version = 3.7
+ignore_missing_imports = true