diff options
-rw-r--r-- | pymemcache/serde.py | 55 | ||||
-rw-r--r-- | pymemcache/test/test_benchmark.py | 8 | ||||
-rw-r--r-- | pymemcache/test/test_compression.py | 220 | ||||
-rw-r--r-- | pymemcache/test/test_integration.py | 87 | ||||
-rw-r--r-- | pymemcache/test/test_serde.py | 107 | ||||
-rw-r--r-- | test-requirements.txt | 2 |
6 files changed, 443 insertions, 36 deletions
diff --git a/pymemcache/serde.py b/pymemcache/serde.py index f92aeb7..6e77766 100644 --- a/pymemcache/serde.py +++ b/pymemcache/serde.py @@ -16,12 +16,13 @@ from functools import partial import logging from io import BytesIO import pickle +import zlib FLAG_BYTES = 0 FLAG_PICKLE = 1 << 0 FLAG_INTEGER = 1 << 1 FLAG_LONG = 1 << 2 -FLAG_COMPRESSED = 1 << 3 # unused, to main compatibility with python-memcached +FLAG_COMPRESSED = 1 << 3 FLAG_TEXT = 1 << 4 # Pickle protocol version (highest available to runtime) @@ -121,6 +122,55 @@ class PickleSerde: return python_memcache_deserializer(key, value, flags) +pickle_serde = PickleSerde() + + +class CompressedSerde: + """ + An object which implements the serialization/deserialization protocol for + :py:class:`pymemcache.client.base.Client` and its descendants with + configurable compression. + """ + + def __init__( + self, + compress=zlib.compress, + decompress=zlib.decompress, + serde=pickle_serde, + # Discovered via the `test_optimal_compression_length` test. + min_compress_len=400, + ): + self._serde = serde + self._compress = compress + self._decompress = decompress + self._min_compress_len = min_compress_len + + def serialize(self, key, value): + value, flags = self._serde.serialize(key, value) + + if len(value) > self._min_compress_len > 0: + old_value = value + value = self._compress(value) + # Don't use the compressed value if our end result is actually + # larger uncompressed. + if len(old_value) < len(value): + value = old_value + else: + flags |= FLAG_COMPRESSED + + return value, flags + + def deserialize(self, key, value, flags): + if flags & FLAG_COMPRESSED: + value = self._decompress(value) + + value = self._serde.deserialize(key, value, flags) + return value + + +compressed_serde = CompressedSerde() + + class LegacyWrappingSerde: """ This class defines how to wrap legacy de/serialization functions into a @@ -141,6 +191,3 @@ class LegacyWrappingSerde: def _default_deserialize(self, key, value, flags): return value - - -pickle_serde = PickleSerde() diff --git a/pymemcache/test/test_benchmark.py b/pymemcache/test/test_benchmark.py index 55653bd..f123482 100644 --- a/pymemcache/test/test_benchmark.py +++ b/pymemcache/test/test_benchmark.py @@ -83,14 +83,16 @@ def benchmark(count, func, *args, **kwargs): @pytest.mark.benchmark() def test_bench_get(request, client, pairs, count): - key, value = next(pairs) + key = "pymemcache_test:0" + value = pairs[key] client.set(key, value) benchmark(count, client.get, key) @pytest.mark.benchmark() def test_bench_set(request, client, pairs, count): - key, value = next(pairs.items()) + key = "pymemcache_test:0" + value = pairs[key] benchmark(count, client.set, key, value) @@ -113,4 +115,4 @@ def test_bench_delete(request, client, pairs, count): @pytest.mark.benchmark() def test_bench_delete_multi(request, client, pairs, count): # deleting missing key takes the same work client-side as real keys - benchmark(count, client.delete_multi, list(pairs)) + benchmark(count, client.delete_multi, list(pairs.keys())) diff --git a/pymemcache/test/test_compression.py b/pymemcache/test/test_compression.py new file mode 100644 index 0000000..7ec6641 --- /dev/null +++ b/pymemcache/test/test_compression.py @@ -0,0 +1,220 @@ +from pymemcache.client.base import Client +from pymemcache.serde import ( + CompressedSerde, + pickle_serde, +) + +from faker import Faker + +import pytest +import random +import string +import time +import zstd # type: ignore +import zlib + +fake = Faker(["it_IT", "en_US", "ja_JP"]) + + +def get_random_string(length): + letters = string.ascii_letters + chars = string.punctuation + digits = string.digits + total = letters + chars + digits + result_str = "".join(random.choice(total) for i in range(length)) + return result_str + + +class CustomObject: + """ + Custom class for verifying serialization + """ + + def __init__(self): + self.number = random.randint(0, 100) + self.string = fake.text() + self.object = fake.profile() + + +class CustomObjectValue: + def __init__(self, value): + self.value = value + + +def benchmark(count, func, *args, **kwargs): + start = time.time() + + for _ in range(count): + result = func(*args, **kwargs) + + duration = time.time() - start + print(str(duration)) + + return result + + +@pytest.fixture(scope="session") +def names(): + names = [] + for _ in range(15): + names.append(fake.name()) + + return names + + +@pytest.fixture(scope="session") +def paragraphs(): + paragraphs = [] + for _ in range(15): + paragraphs.append(fake.text()) + + return paragraphs + + +@pytest.fixture(scope="session") +def objects(): + objects = [] + for _ in range(15): + objects.append(CustomObject()) + + return objects + + +# Always run compression for the benchmarks +min_compress_len = 1 + +default_serde = CompressedSerde(min_compress_len=min_compress_len) + +zlib_serde = CompressedSerde( + compress=lambda value: zlib.compress(value, 9), + decompress=lambda value: zlib.decompress(value), + min_compress_len=min_compress_len, +) + +zstd_serde = CompressedSerde( + compress=lambda value: zstd.compress(value), + decompress=lambda value: zstd.decompress(value), + min_compress_len=min_compress_len, +) + +serializers = [ + None, + default_serde, + zlib_serde, + zstd_serde, +] +ids = ["none", "zlib ", "zlib9", "zstd "] + + +@pytest.mark.benchmark() +@pytest.mark.parametrize("serde", serializers, ids=ids) +def test_bench_compress_set_strings(count, host, port, serde, names): + client = Client((host, port), serde=serde, encoding="utf-8") + + def test(): + for index, name in enumerate(names): + key = f"name_{index}" + client.set(key, name) + + benchmark(count, test) + + +@pytest.mark.benchmark() +@pytest.mark.parametrize("serde", serializers, ids=ids) +def test_bench_compress_get_strings(count, host, port, serde, names): + client = Client((host, port), serde=serde, encoding="utf-8") + for index, name in enumerate(names): + key = f"name_{index}" + client.set(key, name) + + def test(): + for index, _ in enumerate(names): + key = f"name_{index}" + client.get(key) + + benchmark(count, test) + + +@pytest.mark.benchmark() +@pytest.mark.parametrize("serde", serializers, ids=ids) +def test_bench_compress_set_large_strings(count, host, port, serde, paragraphs): + client = Client((host, port), serde=serde, encoding="utf-8") + + def test(): + for index, p in enumerate(paragraphs): + key = f"paragraph_{index}" + client.set(key, p) + + benchmark(count, test) + + +@pytest.mark.benchmark() +@pytest.mark.parametrize("serde", serializers, ids=ids) +def test_bench_compress_get_large_strings(count, host, port, serde, paragraphs): + client = Client((host, port), serde=serde, encoding="utf-8") + for index, p in enumerate(paragraphs): + key = f"paragraphs_{index}" + client.set(key, p) + + def test(): + for index, _ in enumerate(paragraphs): + key = f"paragraphs_{index}" + client.get(key) + + benchmark(count, test) + + +@pytest.mark.benchmark() +@pytest.mark.parametrize("serde", serializers, ids=ids) +def test_bench_compress_set_objects(count, host, port, serde, objects): + client = Client((host, port), serde=serde, encoding="utf-8") + + def test(): + for index, o in enumerate(objects): + key = f"objects_{index}" + client.set(key, o) + + benchmark(count, test) + + +@pytest.mark.benchmark() +@pytest.mark.parametrize("serde", serializers, ids=ids) +def test_bench_compress_get_objects(count, host, port, serde, objects): + client = Client((host, port), serde=serde, encoding="utf-8") + for index, o in enumerate(objects): + key = f"objects_{index}" + client.set(key, o) + + def test(): + for index, _ in enumerate(objects): + key = f"objects_{index}" + client.get(key) + + benchmark(count, test) + + +@pytest.mark.benchmark() +def test_optimal_compression_length(): + for length in range(5, 2000): + input_data = get_random_string(length) + start = len(input_data) + + for index, serializer in enumerate(serializers[1:]): + name = ids[index + 1] + value, _ = serializer.serialize("foo", input_data) + end = len(value) + print(f"serializer={name}\t start={start}\t end={end}") + + +@pytest.mark.benchmark() +def test_optimal_compression_length_objects(): + for length in range(5, 2000): + input_data = get_random_string(length) + obj = CustomObjectValue(input_data) + start = len(pickle_serde.serialize("foo", obj)[0]) + + for index, serializer in enumerate(serializers[1:]): + name = ids[index + 1] + value, _ = serializer.serialize("foo", obj) + end = len(value) + print(f"serializer={name}\t start={start}\t end={end}") diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py index 268287f..961beb6 100644 --- a/pymemcache/test/test_integration.py +++ b/pymemcache/test/test_integration.py @@ -17,8 +17,16 @@ import json import pytest from pymemcache.client.base import Client -from pymemcache.exceptions import MemcacheIllegalInputError, MemcacheClientError -from pymemcache.serde import PickleSerde, pickle_serde +from pymemcache.exceptions import ( + MemcacheIllegalInputError, + MemcacheClientError, + MemcacheServerError, +) +from pymemcache.serde import ( + compressed_serde, + PickleSerde, + pickle_serde, +) def get_set_helper(client, key, value, key2, value2): @@ -41,8 +49,15 @@ def get_set_helper(client, key, value, key2, value2): @pytest.mark.integration() -def test_get_set(client_class, host, port, socket_module): - client = client_class((host, port), socket_module=socket_module) +@pytest.mark.parametrize( + "serde", + [ + pickle_serde, + compressed_serde, + ], +) +def test_get_set(client_class, host, port, serde, socket_module): + client = client_class((host, port), serde=serde, socket_module=socket_module) client.flush_all() key = b"key" @@ -53,9 +68,16 @@ def test_get_set(client_class, host, port, socket_module): @pytest.mark.integration() -def test_get_set_unicode_key(client_class, host, port, socket_module): +@pytest.mark.parametrize( + "serde", + [ + pickle_serde, + compressed_serde, + ], +) +def test_get_set_unicode_key(client_class, host, port, serde, socket_module): client = client_class( - (host, port), socket_module=socket_module, allow_unicode_keys=True + (host, port), serde=serde, socket_module=socket_module, allow_unicode_keys=True ) client.flush_all() @@ -67,8 +89,15 @@ def test_get_set_unicode_key(client_class, host, port, socket_module): @pytest.mark.integration() -def test_add_replace(client_class, host, port, socket_module): - client = client_class((host, port), socket_module=socket_module) +@pytest.mark.parametrize( + "serde", + [ + pickle_serde, + compressed_serde, + ], +) +def test_add_replace(client_class, host, port, serde, socket_module): + client = client_class((host, port), serde=serde, socket_module=socket_module) client.flush_all() result = client.add(b"key", b"value", noreply=False) @@ -277,8 +306,15 @@ def serde_serialization_helper(client_class, host, port, socket_module, serde): @pytest.mark.integration() -def test_serde_serialization(client_class, host, port, socket_module): - serde_serialization_helper(client_class, host, port, socket_module, pickle_serde) +@pytest.mark.parametrize( + "serde", + [ + pickle_serde, + compressed_serde, + ], +) +def test_serde_serialization(client_class, host, port, socket_module, serde): + serde_serialization_helper(client_class, host, port, socket_module, serde) @pytest.mark.integration() @@ -350,3 +386,34 @@ def test_tls(client_class, tls_host, tls_port, socket_module, tls_context): key2 = b"key2" value2 = b"value2" get_set_helper(client, key, value, key2, value2) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "serde,should_fail", + [ + (pickle_serde, True), + (compressed_serde, False), + ], +) +def test_get_set_large( + client_class, + host, + port, + serde, + socket_module, + should_fail, +): + client = client_class((host, port), serde=serde, socket_module=socket_module) + client.flush_all() + + key = b"key" + value = b"value" * 1024 * 1024 + key2 = b"key2" + value2 = b"value2" * 1024 * 1024 + + if should_fail: + with pytest.raises(MemcacheServerError): + get_set_helper(client, key, value, key2, value2) + else: + get_set_helper(client, key, value, key2, value2) diff --git a/pymemcache/test/test_serde.py b/pymemcache/test/test_serde.py index e36e213..cb92271 100644 --- a/pymemcache/test/test_serde.py +++ b/pymemcache/test/test_serde.py @@ -1,15 +1,19 @@ from unittest import TestCase from pymemcache.serde import ( + CompressedSerde, pickle_serde, PickleSerde, FLAG_BYTES, + FLAG_COMPRESSED, FLAG_PICKLE, FLAG_INTEGER, FLAG_TEXT, ) import pytest import pickle +import sys +import zlib class CustomInt(int): @@ -23,39 +27,40 @@ class CustomInt(int): pass -@pytest.mark.unit() -class TestSerde(TestCase): - serde = pickle_serde +def check(serde, value, expected_flags): + serialized, flags = serde.serialize(b"key", value) + assert flags == expected_flags + + # pymemcache stores values as byte strings, so we immediately the value + # if needed so deserialized works as it would with a real server + if not isinstance(serialized, bytes): + serialized = str(serialized).encode("ascii") - def check(self, value, expected_flags): - serialized, flags = self.serde.serialize(b"key", value) - assert flags == expected_flags + deserialized = serde.deserialize(b"key", serialized, flags) + assert deserialized == value - # pymemcache stores values as byte strings, so we immediately the value - # if needed so deserialized works as it would with a real server - if not isinstance(serialized, bytes): - serialized = str(serialized).encode("ascii") - deserialized = self.serde.deserialize(b"key", serialized, flags) - assert deserialized == value +@pytest.mark.unit() +class TestSerde: + serde = pickle_serde def test_bytes(self): - self.check(b"value", FLAG_BYTES) - self.check(b"\xc2\xa3 $ \xe2\x82\xac", FLAG_BYTES) # £ $ € + check(self.serde, b"value", FLAG_BYTES) + check(self.serde, b"\xc2\xa3 $ \xe2\x82\xac", FLAG_BYTES) # £ $ € def test_unicode(self): - self.check("value", FLAG_TEXT) - self.check("£ $ €", FLAG_TEXT) + check(self.serde, "value", FLAG_TEXT) + check(self.serde, "£ $ €", FLAG_TEXT) def test_int(self): - self.check(1, FLAG_INTEGER) + check(self.serde, 1, FLAG_INTEGER) def test_pickleable(self): - self.check({"a": "dict"}, FLAG_PICKLE) + check(self.serde, {"a": "dict"}, FLAG_PICKLE) def test_subtype(self): # Subclass of a native type will be restored as the same type - self.check(CustomInt(123123), FLAG_PICKLE) + check(self.serde, CustomInt(123123), FLAG_PICKLE) @pytest.mark.unit() @@ -76,3 +81,67 @@ class TestSerdePickleVersion2(TestCase): @pytest.mark.unit() class TestSerdePickleVersionHighest(TestCase): serde = PickleSerde(pickle_version=pickle.HIGHEST_PROTOCOL) + + +@pytest.mark.parametrize("serde", [pickle_serde, CompressedSerde()]) +@pytest.mark.unit() +def test_compressed_simple(serde): + # test_bytes + check(serde, b"value", FLAG_BYTES) + check(serde, b"\xc2\xa3 $ \xe2\x82\xac", FLAG_BYTES) # £ $ € + + # test_unicode + check(serde, "value", FLAG_TEXT) + check(serde, "£ $ €", FLAG_TEXT) + + # test_int + check(serde, 1, FLAG_INTEGER) + + # test_pickleable + check(serde, {"a": "dict"}, FLAG_PICKLE) + + # test_subtype + # Subclass of a native type will be restored as the same type + check(serde, CustomInt(12312), FLAG_PICKLE) + + +@pytest.mark.parametrize( + "serde", + [ + CompressedSerde(min_compress_len=49), + # Custom compression. This could be something like lz4 + CompressedSerde( + compress=lambda value: zlib.compress(value, 9), + decompress=lambda value: zlib.decompress(value), + min_compress_len=49, + ), + ], +) +@pytest.mark.unit() +def test_compressed_complex(serde): + # test_bytes + check(serde, b"value" * 10, FLAG_BYTES | FLAG_COMPRESSED) + check(serde, b"\xc2\xa3 $ \xe2\x82\xac" * 10, FLAG_BYTES | FLAG_COMPRESSED) # £ $ € + + # test_unicode + check(serde, "value" * 10, FLAG_TEXT | FLAG_COMPRESSED) + check(serde, "£ $ €" * 10, FLAG_TEXT | FLAG_COMPRESSED) + + # test_int, doesn't make sense to compress + check(serde, sys.maxsize, FLAG_INTEGER) + + # test_pickleable + check( + serde, + { + "foo": "bar", + "baz": "qux", + "uno": "dos", + "tres": "tres", + }, + FLAG_PICKLE | FLAG_COMPRESSED, + ) + + # test_subtype + # Subclass of a native type will be restored as the same type + check(serde, CustomInt(sys.maxsize), FLAG_PICKLE | FLAG_COMPRESSED) diff --git a/test-requirements.txt b/test-requirements.txt index ac08524..a707c65 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,5 +1,7 @@ +Faker==13.15.0 pytest==7.1.1 pytest-cov==3.0.0 gevent==21.12.0; "PyPy" not in platform_python_implementation pylibmc==1.6.1; sys.platform != 'win32' python-memcached==1.59 +zstd==1.5.2.5 |