From 6fe50f5985b01c1f9b172f740c51400540bfc92d Mon Sep 17 00:00:00 2001 From: Serhii Tereshchenko Date: Thu, 2 Mar 2023 12:18:21 +0200 Subject: refactor: Refactor utils/json (#1659) * refactor: Refactor utils/json * Update kombu/utils/json.py * Update kombu/utils/json.py * chore: Use older syntax (no walrus) * chore: Update doscstrings * chore: Fix pydocstyle complaints * chore: Restore previous docstring --- kombu/utils/json.py | 167 ++++++++++++++++++++++++++-------------------- t/unit/utils/test_json.py | 34 ++++------ 2 files changed, 107 insertions(+), 94 deletions(-) diff --git a/kombu/utils/json.py b/kombu/utils/json.py index 95a05a68..ec6269e2 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -3,88 +3,70 @@ from __future__ import annotations import base64 -import datetime -import decimal import json import uuid +from datetime import date, datetime, time +from decimal import Decimal +from typing import Any, Callable, TypeVar -try: - from django.utils.functional import Promise as DjangoPromise -except ImportError: # pragma: no cover - class DjangoPromise: - """Dummy object.""" +textual_types = () +try: + from django.utils.functional import Promise -class _DecodeError(Exception): + textual_types += (Promise,) +except ImportError: pass -_encoder_cls = type(json._default_encoder) -_default_encoder = None # ... set to JSONEncoder below. - - -class JSONEncoder(_encoder_cls): +class JSONEncoder(json.JSONEncoder): """Kombu custom json encoder.""" - def default(self, o, - dates=(datetime.datetime, datetime.date), - times=(datetime.time,), - textual=(decimal.Decimal, DjangoPromise), - isinstance=isinstance, - datetime=datetime.datetime, - text_t=str): - reducer = getattr(o, '__json__', None) + def default(self, o): + reducer = getattr(o, "__json__", None) if reducer is not None: return reducer() - else: - if isinstance(o, dates): - marker = "__date__" - if not isinstance(o, datetime): - o = datetime(o.year, o.month, o.day, 0, 0, 0, 0) - else: - marker = "__datetime__" - r = o.isoformat() - return {"datetime": r, marker: True} - elif isinstance(o, times): - return o.isoformat() - elif isinstance(o, uuid.UUID): - return {"uuid": str(o), "__uuid__": True, "version": o.version} - elif isinstance(o, textual): - return text_t(o) - elif isinstance(o, bytes): - try: - return {"bytes": o.decode("utf-8"), "__bytes__": True} - except UnicodeDecodeError: - return { - "bytes": base64.b64encode(o).decode("utf-8"), - "__base64__": True, - } - return super().default(o) - - -_default_encoder = JSONEncoder - - -def dumps(s, _dumps=json.dumps, cls=None, default_kwargs=None, **kwargs): + + if isinstance(o, textual_types): + return str(o) + + for t, (marker, encoder) in _encoders.items(): + if isinstance(o, t): + return _as(marker, encoder(o)) + + # Bytes is slightly trickier, so we cannot put them directly + # into _encoders, because we use two formats: bytes, and base64. + if isinstance(o, bytes): + try: + return _as("bytes", o.decode("utf-8")) + except UnicodeDecodeError: + return _as("base64", base64.b64encode(o).decode("utf-8")) + + return super().default(o) + + +def _as(t: str, v: Any): + return {"__type__": t, "__value__": v} + + +def dumps( + s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs +): """Serialize object to json string.""" default_kwargs = default_kwargs or {} - return _dumps(s, cls=cls or _default_encoder, - **dict(default_kwargs, **kwargs)) + return _dumps(s, cls=cls, **dict(default_kwargs, **kwargs)) -def object_hook(dct): +def object_hook(o: dict): """Hook function to perform custom deserialization.""" - if "__date__" in dct: - return datetime.datetime.fromisoformat(dct["datetime"]).date() - if "__datetime__" in dct: - return datetime.datetime.fromisoformat(dct["datetime"]) - if "__bytes__" in dct: - return dct["bytes"].encode("utf-8") - if "__base64__" in dct: - return base64.b64decode(dct["bytes"].encode("utf-8")) - if "__uuid__" in dct: - return uuid.UUID(dct["uuid"], version=dct["version"]) - return dct + if o.keys() == {"__type__", "__value__"}: + decoder = _decoders.get(o["__type__"]) + if decoder: + return decoder(o["__value__"]) + else: + raise ValueError("Unsupported type", type, o) + else: + return o def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): @@ -96,14 +78,51 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): # over. Note that pickle does support buffer/memoryview # if isinstance(s, memoryview): - s = s.tobytes().decode('utf-8') + s = s.tobytes().decode("utf-8") elif isinstance(s, bytearray): - s = s.decode('utf-8') + s = s.decode("utf-8") elif decode_bytes and isinstance(s, bytes): - s = s.decode('utf-8') - - try: - return _loads(s, object_hook=object_hook) - except _DecodeError: - # catch "Unpaired high surrogate" error - return json.loads(s) + s = s.decode("utf-8") + + return _loads(s, object_hook=object_hook) + + +DecoderT = EncoderT = Callable[[Any], Any] +T = TypeVar("T") +EncodedT = TypeVar("EncodedT") + + +def register_type( + t: type[T], + marker: str, + encoder: Callable[[T], EncodedT], + decoder: Callable[[EncodedT], T], +): + """Add support for serializing/deserializing native python type.""" + _encoders[t] = (marker, encoder) + _decoders[marker] = decoder + + +_encoders: dict[type, tuple[str, EncoderT]] = {} +_decoders: dict[str, DecoderT] = { + "bytes": lambda o: o.encode("utf-8"), + "base64": lambda o: base64.b64decode(o.encode("utf-8")), +} + +# NOTE: datetime should be registered before date, +# because datetime is also instance of date. +register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat) +register_type( + date, + "date", + lambda o: o.isoformat(), + lambda o: datetime.fromisoformat(o).date(), +) +register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) +register_type(Decimal, "decimal", str, Decimal) +register_type( + uuid.UUID, + "uuid", + lambda o: {"hex": o.hex, "version": o.version}, + lambda o: uuid.UUID(**o), +) diff --git a/t/unit/utils/test_json.py b/t/unit/utils/test_json.py index ec7ea102..8dcc7e32 100644 --- a/t/unit/utils/test_json.py +++ b/t/unit/utils/test_json.py @@ -2,9 +2,8 @@ from __future__ import annotations import uuid from collections import namedtuple -from datetime import date, datetime +from datetime import datetime from decimal import Decimal -from unittest.mock import MagicMock, Mock import pytest import pytz @@ -12,7 +11,7 @@ from hypothesis import given, settings from hypothesis import strategies as st from kombu.utils.encoding import str_to_bytes -from kombu.utils.json import _DecodeError, dumps, loads +from kombu.utils.json import dumps, loads class Custom: @@ -29,19 +28,18 @@ class test_JSONEncoder: def test_datetime(self): now = datetime.utcnow() now_utc = now.replace(tzinfo=pytz.utc) - serialized = loads(dumps({ + + original = { 'datetime': now, 'tz': now_utc, 'date': now.date(), - 'time': now.time()}, - )) - assert serialized == { - 'datetime': now, - 'tz': now_utc, - 'time': now.time().isoformat(), - 'date': date(now.year, now.month, now.day), + 'time': now.time(), } + serialized = loads(dumps(original)) + + assert serialized == original + @given(message=st.binary()) @settings(print_blob=True) def test_binary(self, message): @@ -53,8 +51,10 @@ class test_JSONEncoder: } def test_Decimal(self): - d = Decimal('3314132.13363235235324234123213213214134') - assert loads(dumps({'d': d})), {'d': str(d)} + original = {'d': Decimal('3314132.13363235235324234123213213214134')} + serialized = loads(dumps(original)) + + assert serialized == original def test_namedtuple(self): Foo = namedtuple('Foo', ['bar']) @@ -70,7 +70,7 @@ class test_JSONEncoder: for constructor in constructors: id = constructor() loaded_value = loads(dumps({'u': id})) - assert loaded_value, {'u': id} + assert loaded_value == {'u': id} assert loaded_value["u"].version == id.version def test_default(self): @@ -103,9 +103,3 @@ class test_dumps_loads: assert loads( str_to_bytes(dumps({'x': 'z'})), decode_bytes=True) == {'x': 'z'} - - def test_loads_DecodeError(self): - _loads = Mock(name='_loads') - _loads.side_effect = _DecodeError( - MagicMock(), MagicMock(), MagicMock()) - assert loads(dumps({'x': 'z'}), _loads=_loads) == {'x': 'z'} -- cgit v1.2.1