import json
import warnings

from django import forms
from django.core import checks, exceptions
from django.db import NotSupportedError, connections, router
from django.db.models import expressions, lookups
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import TextField
from django.db.models.lookups import (
    FieldGetDbPrepValueMixin,
    PostgresOperatorLookup,
    Transform,
)
from django.utils.deprecation import RemovedInDjango51Warning
from django.utils.translation import gettext_lazy as _

from . import Field
from .mixins import CheckFieldDefaultMixin

__all__ = ["JSONField"]


class JSONField(CheckFieldDefaultMixin, Field):
    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {
        "invalid": _("Value must be valid JSON."),
    }
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        errors = super().check(**kwargs)
        databases = kwargs.get("databases") or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or connection.features.supports_json_field
            ):
                errors.append(
                    checks.Error(
                        "%s does not support JSONFields." % connection.display_name,
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self):
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        # RemovedInDjango51Warning: When the deprecation ends, replace with:
        # if (
        #     isinstance(value, expressions.Value)
        #     and isinstance(value.output_field, JSONField)
        # ):
        #     value = value.value
        # elif hasattr(value, "as_sql"): ...
        if isinstance(value, expressions.Value):
            if isinstance(value.value, str) and not isinstance(
                value.output_field, JSONField
            ):
                try:
                    value = json.loads(value.value, cls=self.decoder)
                except json.JSONDecodeError:
                    value = value.value
                else:
                    warnings.warn(
                        "Providing an encoded JSON string via Value() is deprecated. "
                        f"Use Value({value!r}, output_field=JSONField()) instead.",
                        category=RemovedInDjango51Warning,
                    )
            elif isinstance(value.output_field, JSONField):
                value = value.value
            else:
                return value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        return super().formfield(
            **{
                "form_class": forms.JSONField,
                "encoder": self.encoder,
                "decoder": self.decoder,
                **kwargs,
            }
        )
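
# Illustrative only (not part of this module): a minimal sketch of declaring
# the field on a hypothetical ``Dog`` model. ``encoder``/``decoder`` are the
# optional callables validated in ``JSONField.__init__`` above, and ``dict``
# matches the ``_default_hint`` used by ``CheckFieldDefaultMixin``.
#
#     class Dog(models.Model):
#         data = models.JSONField(default=dict, blank=True)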


def compile_json_path(key_transforms, include_root=True):
    path = ["$"] if include_root else []
    for key_transform in key_transforms:
        try:
            num = int(key_transform)
        except ValueError:  # non-integer
            path.append(".")
            path.append(json.dumps(key_transform))
        else:
            path.append("[%s]" % num)
    return "".join(path)


class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = tuple(lhs_params) + tuple(rhs_params)
        return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params


class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = tuple(rhs_params) + tuple(lhs_params)
        return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params


class HasKeyLookup(PostgresOperatorLookup):
    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_oracle(self, compiler, connection):
        sql, params = self.as_sql(
            compiler, connection, template="JSON_EXISTS(%s, '%%s')"
        )
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle.
        return sql % tuple(params), []

    def as_postgresql(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )


class HasKey(HasKeyLookup):
    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False


class HasKeys(HasKeyLookup):
    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        return [str(item) for item in self.rhs]


class HasAnyKeys(HasKeys):
    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "


class HasKeyOrArrayIndex(HasKey):
    def compile_json_path_final_key(self, key_transform):
        return compile_json_path([key_transform], include_root=False)


class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if connection.vendor == "mysql":
            return "LOWER(%s)" % lhs, lhs_params
        return lhs, lhs_params

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "mysql":
            return "LOWER(%s)" % rhs, rhs_params
        return rhs, rhs_params


class JSONExact(lookups.Exact):
    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
            rhs %= tuple(func)
        return rhs, rhs_params


class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    pass


JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
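
# Illustrative only: how the lookups registered above are spelled in queries,
# assuming the hypothetical ``Dog`` model with ``data = JSONField()``:
#
#     Dog.objects.filter(data__contains={"breed": "collie"})    # DataContains
#     Dog.objects.filter(data__has_key="owner")                  # HasKey
#     Dog.objects.filter(data__has_any_keys=["breed", "owner"])  # HasAnyKeys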


class KeyTransform(Transform):
    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        if connection.vendor == "oracle":
            # Escape string-formatting.
            key_transforms = [key.replace("%", "%%") for key in key_transforms]
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)

    def as_oracle(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return (
            "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
            % ((lhs, json_path) * 2)
        ), tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
            return sql, tuple(params) + (key_transforms,)
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)

    def as_sqlite(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        datatype_values = ",".join(
            [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
        )
        return (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3


class KeyTextTransform(KeyTransform):
    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if connection.mysql_is_mariadb:
            # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
            sql, params = super().as_mysql(compiler, connection)
            return "JSON_UNQUOTE(%s)" % sql, params
        else:
            lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
            json_path = compile_json_path(key_transforms)
            return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)

    @classmethod
    def from_lookup(cls, lookup):
        transform, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        for key in keys:
            transform = cls(key, transform)
        return transform


KT = KeyTextTransform.from_lookup
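
# ``KT()`` builds a text-valued key path from a lookup string. A minimal
# sketch, reusing the hypothetical ``Dog`` model:
#
#     KT("data__owner__name")
#     # == KeyTextTransform("name", KeyTextTransform("owner", "data"))
#     Dog.objects.annotate(owner_name=KT("data__owner__name"))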


class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        key_text_transform = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(key_text_transform, *args, **kwargs)


class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_oracle(self, compiler, connection):
        sql, params = HasKeyOrArrayIndex(
            self.lhs.lhs,
            self.lhs.key_name,
        ).as_oracle(compiler, connection)
        if not self.rhs:
            return sql, params
        # Column doesn't have a key or IS NULL.
        lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
        return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)

    def as_sqlite(self, compiler, connection):
        template = "JSON_TYPE(%s, %%s) IS NULL"
        if not self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
            compiler,
            connection,
            template=template,
        )


class KeyTransformIn(lookups.In):
    def resolve_expression_parameter(self, compiler, connection, sql, param):
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "oracle":
                value = json.loads(param)
                sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                if isinstance(value, (list, dict)):
                    sql %= "JSON_QUERY"
                else:
                    sql %= "JSON_VALUE"
            elif connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params


class KeyTransformExact(JSONExact):
    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                value = json.loads(value)
                if isinstance(value, (list, dict)):
                    func.append(sql % "JSON_QUERY")
                else:
                    func.append(sql % "JSON_VALUE")
            rhs %= tuple(func)
        elif connection.vendor == "sqlite":
            func = []
            for value in rhs_params:
                if value in connection.ops.jsonfield_datatype_values:
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params == ["null"]:
            # Field has key and it's NULL.
            has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                "%s AND %s" % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)


class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    pass


class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    pass


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    pass


class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    pass


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    pass


class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    pass


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    pass


class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    pass


class KeyTransformNumericLookupMixin:
    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not connection.features.has_native_json_field:
            rhs_params = [json.loads(value) for value in rhs_params]
        return rhs, rhs_params


class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    pass


class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    pass


class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    pass


class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    pass


KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)
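
# Illustrative only: with the registrations above, key transforms chain with
# ordinary lookups on the hypothetical ``Dog`` model:
#
#     Dog.objects.filter(data__owner__name__icontains="bob")  # KeyTransformIContains
#     Dog.objects.filter(data__age__gt=3)                      # KeyTransformGt
#     Dog.objects.filter(data__owner__isnull=False)            # KeyTransformIsNull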


class KeyTransformFactory:
    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        return KeyTransform(self.key_name, *args, **kwargs)