summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Serhii Tereshchenko <serg.partizan@gmail.com> 2023-03-02 12:18:21 +0200
committer: GitHub <noreply@github.com> 2023-03-02 16:18:21 +0600
commit: 6fe50f5985b01c1f9b172f740c51400540bfc92d (patch)
tree: 50cef4d449b3e15064061f17521c12fea6893bf2
parent: b4c033d5567d71423b4144accaaa6887b532d693 (diff)
download: kombu-6fe50f5985b01c1f9b172f740c51400540bfc92d.tar.gz
refactor: Refactor utils/json (#1659)
* refactor: Refactor utils/json
* Update kombu/utils/json.py
* Update kombu/utils/json.py
* chore: Use older syntax (no walrus)
* chore: Update docstrings
* chore: Fix pydocstyle complaints
* chore: Restore previous docstring
-rw-r--r-- kombu/utils/json.py 167
-rw-r--r-- t/unit/utils/test_json.py 34
2 files changed, 107 insertions, 94 deletions
diff --git a/kombu/utils/json.py b/kombu/utils/json.py
index 95a05a68..ec6269e2 100644
--- a/kombu/utils/json.py
+++ b/kombu/utils/json.py
@@ -3,88 +3,70 @@
from __future__ import annotations
import base64
-import datetime
-import decimal
import json
import uuid
+from datetime import date, datetime, time
+from decimal import Decimal
+from typing import Any, Callable, TypeVar
-try:
- from django.utils.functional import Promise as DjangoPromise
-except ImportError: # pragma: no cover
- class DjangoPromise:
- """Dummy object."""
+textual_types = ()
+try:
+ from django.utils.functional import Promise
-class _DecodeError(Exception):
+ textual_types += (Promise,)
+except ImportError:
pass
-_encoder_cls = type(json._default_encoder)
-_default_encoder = None # ... set to JSONEncoder below.
-
-
-class JSONEncoder(_encoder_cls):
+class JSONEncoder(json.JSONEncoder):
"""Kombu custom json encoder."""
- def default(self, o,
- dates=(datetime.datetime, datetime.date),
- times=(datetime.time,),
- textual=(decimal.Decimal, DjangoPromise),
- isinstance=isinstance,
- datetime=datetime.datetime,
- text_t=str):
- reducer = getattr(o, '__json__', None)
+ def default(self, o):
+ reducer = getattr(o, "__json__", None)
if reducer is not None:
return reducer()
- else:
- if isinstance(o, dates):
- marker = "__date__"
- if not isinstance(o, datetime):
- o = datetime(o.year, o.month, o.day, 0, 0, 0, 0)
- else:
- marker = "__datetime__"
- r = o.isoformat()
- return {"datetime": r, marker: True}
- elif isinstance(o, times):
- return o.isoformat()
- elif isinstance(o, uuid.UUID):
- return {"uuid": str(o), "__uuid__": True, "version": o.version}
- elif isinstance(o, textual):
- return text_t(o)
- elif isinstance(o, bytes):
- try:
- return {"bytes": o.decode("utf-8"), "__bytes__": True}
- except UnicodeDecodeError:
- return {
- "bytes": base64.b64encode(o).decode("utf-8"),
- "__base64__": True,
- }
- return super().default(o)
-
-
-_default_encoder = JSONEncoder
-
-
-def dumps(s, _dumps=json.dumps, cls=None, default_kwargs=None, **kwargs):
+
+ if isinstance(o, textual_types):
+ return str(o)
+
+ for t, (marker, encoder) in _encoders.items():
+ if isinstance(o, t):
+ return _as(marker, encoder(o))
+
+ # Bytes is slightly trickier, so we cannot put them directly
+ # into _encoders, because we use two formats: bytes, and base64.
+ if isinstance(o, bytes):
+ try:
+ return _as("bytes", o.decode("utf-8"))
+ except UnicodeDecodeError:
+ return _as("base64", base64.b64encode(o).decode("utf-8"))
+
+ return super().default(o)
+
+
+def _as(t: str, v: Any):
+ return {"__type__": t, "__value__": v}
+
+
+def dumps(
+ s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs
+):
"""Serialize object to json string."""
default_kwargs = default_kwargs or {}
- return _dumps(s, cls=cls or _default_encoder,
- **dict(default_kwargs, **kwargs))
+ return _dumps(s, cls=cls, **dict(default_kwargs, **kwargs))
-def object_hook(dct):
+def object_hook(o: dict):
"""Hook function to perform custom deserialization."""
- if "__date__" in dct:
- return datetime.datetime.fromisoformat(dct["datetime"]).date()
- if "__datetime__" in dct:
- return datetime.datetime.fromisoformat(dct["datetime"])
- if "__bytes__" in dct:
- return dct["bytes"].encode("utf-8")
- if "__base64__" in dct:
- return base64.b64decode(dct["bytes"].encode("utf-8"))
- if "__uuid__" in dct:
- return uuid.UUID(dct["uuid"], version=dct["version"])
- return dct
+ if o.keys() == {"__type__", "__value__"}:
+ decoder = _decoders.get(o["__type__"])
+ if decoder:
+ return decoder(o["__value__"])
+ else:
+ raise ValueError("Unsupported type", type, o)
+ else:
+ return o
def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
@@ -96,14 +78,51 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
# over. Note that pickle does support buffer/memoryview
# </rant>
if isinstance(s, memoryview):
- s = s.tobytes().decode('utf-8')
+ s = s.tobytes().decode("utf-8")
elif isinstance(s, bytearray):
- s = s.decode('utf-8')
+ s = s.decode("utf-8")
elif decode_bytes and isinstance(s, bytes):
- s = s.decode('utf-8')
-
- try:
- return _loads(s, object_hook=object_hook)
- except _DecodeError:
- # catch "Unpaired high surrogate" error
- return json.loads(s)
+ s = s.decode("utf-8")
+
+ return _loads(s, object_hook=object_hook)
+
+
+DecoderT = EncoderT = Callable[[Any], Any]
+T = TypeVar("T")
+EncodedT = TypeVar("EncodedT")
+
+
+def register_type(
+ t: type[T],
+ marker: str,
+ encoder: Callable[[T], EncodedT],
+ decoder: Callable[[EncodedT], T],
+):
+ """Add support for serializing/deserializing native python type."""
+ _encoders[t] = (marker, encoder)
+ _decoders[marker] = decoder
+
+
+_encoders: dict[type, tuple[str, EncoderT]] = {}
+_decoders: dict[str, DecoderT] = {
+ "bytes": lambda o: o.encode("utf-8"),
+ "base64": lambda o: base64.b64decode(o.encode("utf-8")),
+}
+
+# NOTE: datetime should be registered before date,
+# because datetime is also instance of date.
+register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat)
+register_type(
+ date,
+ "date",
+ lambda o: o.isoformat(),
+ lambda o: datetime.fromisoformat(o).date(),
+)
+register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
+register_type(Decimal, "decimal", str, Decimal)
+register_type(
+ uuid.UUID,
+ "uuid",
+ lambda o: {"hex": o.hex, "version": o.version},
+ lambda o: uuid.UUID(**o),
+)
diff --git a/t/unit/utils/test_json.py b/t/unit/utils/test_json.py
index ec7ea102..8dcc7e32 100644
--- a/t/unit/utils/test_json.py
+++ b/t/unit/utils/test_json.py
@@ -2,9 +2,8 @@ from __future__ import annotations
import uuid
from collections import namedtuple
-from datetime import date, datetime
+from datetime import datetime
from decimal import Decimal
-from unittest.mock import MagicMock, Mock
import pytest
import pytz
@@ -12,7 +11,7 @@ from hypothesis import given, settings
from hypothesis import strategies as st
from kombu.utils.encoding import str_to_bytes
-from kombu.utils.json import _DecodeError, dumps, loads
+from kombu.utils.json import dumps, loads
class Custom:
@@ -29,19 +28,18 @@ class test_JSONEncoder:
def test_datetime(self):
now = datetime.utcnow()
now_utc = now.replace(tzinfo=pytz.utc)
- serialized = loads(dumps({
+
+ original = {
'datetime': now,
'tz': now_utc,
'date': now.date(),
- 'time': now.time()},
- ))
- assert serialized == {
- 'datetime': now,
- 'tz': now_utc,
- 'time': now.time().isoformat(),
- 'date': date(now.year, now.month, now.day),
+ 'time': now.time(),
}
+ serialized = loads(dumps(original))
+
+ assert serialized == original
+
@given(message=st.binary())
@settings(print_blob=True)
def test_binary(self, message):
@@ -53,8 +51,10 @@ class test_JSONEncoder:
}
def test_Decimal(self):
- d = Decimal('3314132.13363235235324234123213213214134')
- assert loads(dumps({'d': d})), {'d': str(d)}
+ original = {'d': Decimal('3314132.13363235235324234123213213214134')}
+ serialized = loads(dumps(original))
+
+ assert serialized == original
def test_namedtuple(self):
Foo = namedtuple('Foo', ['bar'])
@@ -70,7 +70,7 @@ class test_JSONEncoder:
for constructor in constructors:
id = constructor()
loaded_value = loads(dumps({'u': id}))
- assert loaded_value, {'u': id}
+ assert loaded_value == {'u': id}
assert loaded_value["u"].version == id.version
def test_default(self):
@@ -103,9 +103,3 @@ class test_dumps_loads:
assert loads(
str_to_bytes(dumps({'x': 'z'})),
decode_bytes=True) == {'x': 'z'}
-
- def test_loads_DecodeError(self):
- _loads = Mock(name='_loads')
- _loads.side_effect = _DecodeError(
- MagicMock(), MagicMock(), MagicMock())
- assert loads(dumps({'x': 'z'}), _loads=_loads) == {'x': 'z'}