diff options
author | MichaĆ Modzelewski <michal.modzelewski@gmail.com> | 2015-01-02 02:39:31 +0100 |
---|---|---|
committer | Tim Graham <timograham@gmail.com> | 2015-01-12 18:15:34 -0500 |
commit | 65246de7b1d70d25831ab394c4f4a75813f629fe (patch) | |
tree | 618c5f030f9a77d240dc59b132dd1e152baca116 /tests/expressions_case | |
parent | aa8ee6a5731b37b73635e7605521fb1a54a5c10d (diff) | |
download | django-65246de7b1d70d25831ab394c4f4a75813f629fe.tar.gz |
Fixed #24031 -- Added CASE expressions to the ORM.
Diffstat (limited to 'tests/expressions_case')
-rw-r--r-- | tests/expressions_case/__init__.py | 0 | ||||
-rw-r--r-- | tests/expressions_case/models.py | 80 | ||||
-rw-r--r-- | tests/expressions_case/tests.py | 1083 |
3 files changed, 1163 insertions, 0 deletions
diff --git a/tests/expressions_case/__init__.py b/tests/expressions_case/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/expressions_case/__init__.py diff --git a/tests/expressions_case/models.py b/tests/expressions_case/models.py new file mode 100644 index 0000000000..fcee0cdbd3 --- /dev/null +++ b/tests/expressions_case/models.py @@ -0,0 +1,80 @@ +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class CaseTestModel(models.Model): + integer = models.IntegerField() + integer2 = models.IntegerField(null=True) + string = models.CharField(max_length=100, default='') + + big_integer = models.BigIntegerField(null=True) + binary = models.BinaryField(default=b'') + boolean = models.BooleanField(default=False) + comma_separated_integer = models.CommaSeparatedIntegerField(max_length=100, default='') + date = models.DateField(null=True, db_column='date_field') + date_time = models.DateTimeField(null=True) + decimal = models.DecimalField(max_digits=2, decimal_places=1, null=True, db_column='decimal_field') + duration = models.DurationField(null=True) + email = models.EmailField(default='') + file = models.FileField(null=True, db_column='file_field') + file_path = models.FilePathField(null=True) + float = models.FloatField(null=True, db_column='float_field') + image = models.ImageField(null=True) + ip_address = models.IPAddressField(null=True) + generic_ip_address = models.GenericIPAddressField(null=True) + null_boolean = models.NullBooleanField() + positive_integer = models.PositiveIntegerField(null=True) + positive_small_integer = models.PositiveSmallIntegerField(null=True) + slug = models.SlugField(default='') + small_integer = models.SmallIntegerField(null=True) + text = models.TextField(default='') + time = models.TimeField(null=True, db_column='time_field') + url = models.URLField(default='') + uuid = models.UUIDField(null=True) + fk = models.ForeignKey('self', null=True) + + def __str__(self): + return "%i, %s" % (self.integer, self.string) + + +@python_2_unicode_compatible +class O2OCaseTestModel(models.Model): + o2o = models.OneToOneField(CaseTestModel, related_name='o2o_rel') + integer = models.IntegerField() + + def __str__(self): + return "%i, %s" % (self.id, self.o2o) + + +@python_2_unicode_compatible +class FKCaseTestModel(models.Model): + fk = models.ForeignKey(CaseTestModel, related_name='fk_rel') + integer = models.IntegerField() + + def __str__(self): + return "%i, %s" % (self.id, self.fk) + + +@python_2_unicode_compatible +class Client(models.Model): + REGULAR = 'R' + GOLD = 'G' + PLATINUM = 'P' + ACCOUNT_TYPE_CHOICES = ( + (REGULAR, 'Regular'), + (GOLD, 'Gold'), + (PLATINUM, 'Platinum'), + ) + name = models.CharField(max_length=50) + registered_on = models.DateField() + account_type = models.CharField( + max_length=1, + choices=ACCOUNT_TYPE_CHOICES, + default=REGULAR, + ) + + def __str__(self): + return self.name diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py new file mode 100644 index 0000000000..20bdb840fa --- /dev/null +++ b/tests/expressions_case/tests.py @@ -0,0 +1,1083 @@ +from __future__ import unicode_literals + +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from operator import attrgetter, itemgetter +from uuid import UUID + +from django.core.exceptions import FieldError +from django.db import models +from django.db.models import F, Q, Value, Min, Max +from django.db.models.expressions import Case, When +from django.test import TestCase +from django.utils import six + +from .models import CaseTestModel, O2OCaseTestModel, FKCaseTestModel, Client + + +class CaseExpressionTests(TestCase): + @classmethod + def setUpTestData(cls): + o = CaseTestModel.objects.create(integer=1, integer2=1, string='1') + O2OCaseTestModel.objects.create(o2o=o, integer=1) + FKCaseTestModel.objects.create(fk=o, integer=1) + + o = CaseTestModel.objects.create(integer=2, integer2=3, string='2') + O2OCaseTestModel.objects.create(o2o=o, integer=2) + FKCaseTestModel.objects.create(fk=o, integer=2) + FKCaseTestModel.objects.create(fk=o, integer=3) + + o = CaseTestModel.objects.create(integer=3, integer2=4, string='3') + O2OCaseTestModel.objects.create(o2o=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=4) + + o = CaseTestModel.objects.create(integer=2, integer2=2, string='2') + O2OCaseTestModel.objects.create(o2o=o, integer=2) + FKCaseTestModel.objects.create(fk=o, integer=2) + FKCaseTestModel.objects.create(fk=o, integer=3) + + o = CaseTestModel.objects.create(integer=3, integer2=4, string='3') + O2OCaseTestModel.objects.create(o2o=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=4) + + o = CaseTestModel.objects.create(integer=3, integer2=3, string='3') + O2OCaseTestModel.objects.create(o2o=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=3) + FKCaseTestModel.objects.create(fk=o, integer=4) + + o = CaseTestModel.objects.create(integer=4, integer2=5, string='4') + O2OCaseTestModel.objects.create(o2o=o, integer=1) + FKCaseTestModel.objects.create(fk=o, integer=5) + + # GROUP BY on Oracle fails with TextField/BinaryField; see #24096. + cls.non_lob_fields = [ + f.name for f in CaseTestModel._meta.get_fields() + if not (f.is_relation and f.auto_created) and not isinstance(f, (models.BinaryField, models.TextField)) + ] + + def test_annotate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(test=Case( + When(integer=1, then=Value('one')), + When(integer=2, then=Value('two')), + default=Value('other'), + output_field=models.CharField(), + )).order_by('pk'), + [(1, 'one'), (2, 'two'), (3, 'other'), (2, 'two'), (3, 'other'), (3, 'other'), (4, 'other')], + transform=attrgetter('integer', 'test') + ) + + def test_annotate_without_default(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(test=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + output_field=models.IntegerField(), + )).order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'test') + ) + + def test_annotate_with_expression_as_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(f_test=Case( + When(integer=1, then=F('integer') + 1), + When(integer=2, then=F('integer') + 3), + default='integer', + )).order_by('pk'), + [(1, 2), (2, 5), (3, 3), (2, 5), (3, 3), (3, 3), (4, 4)], + transform=attrgetter('integer', 'f_test') + ) + + def test_annotate_with_expression_as_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(f_test=Case( + When(integer2=F('integer'), then=Value('equal')), + When(integer2=F('integer') + 1, then=Value('+1')), + output_field=models.CharField(), + )).order_by('pk'), + [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')], + transform=attrgetter('integer', 'f_test') + ) + + def test_annotate_with_join_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(join_test=Case( + When(integer=1, then=F('o2o_rel__integer') + 1), + When(integer=2, then=F('o2o_rel__integer') + 3), + default='o2o_rel__integer', + )).order_by('pk'), + [(1, 2), (2, 5), (3, 3), (2, 5), (3, 3), (3, 3), (4, 1)], + transform=attrgetter('integer', 'join_test') + ) + + def test_annotate_with_join_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(join_test=Case( + When(integer2=F('o2o_rel__integer'), then=Value('equal')), + When(integer2=F('o2o_rel__integer') + 1, then=Value('+1')), + default=Value('other'), + output_field=models.CharField(), + )).order_by('pk'), + [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, 'other')], + transform=attrgetter('integer', 'join_test') + ) + + def test_annotate_with_join_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate(join_test=Case( + When(o2o_rel__integer=1, then=Value('one')), + When(o2o_rel__integer=2, then=Value('two')), + When(o2o_rel__integer=3, then=Value('three')), + default=Value('other'), + output_field=models.CharField(), + )).order_by('pk'), + [(1, 'one'), (2, 'two'), (3, 'three'), (2, 'two'), (3, 'three'), (3, 'three'), (4, 'one')], + transform=attrgetter('integer', 'join_test') + ) + + def test_annotate_with_annotation_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_plus_1=F('integer') + 1, + f_plus_3=F('integer') + 3, + ).annotate( + f_test=Case( + When(integer=1, then='f_plus_1'), + When(integer=2, then='f_plus_3'), + default='integer', + ), + ).order_by('pk'), + [(1, 2), (2, 5), (3, 3), (2, 5), (3, 3), (3, 3), (4, 4)], + transform=attrgetter('integer', 'f_test') + ) + + def test_annotate_with_annotation_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_plus_1=F('integer') + 1, + ).annotate( + f_test=Case( + When(integer2=F('integer'), then=Value('equal')), + When(integer2=F('f_plus_1'), then=Value('+1')), + output_field=models.CharField(), + ), + ).order_by('pk'), + [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')], + transform=attrgetter('integer', 'f_test') + ) + + def test_annotate_with_annotation_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_minus_2=F('integer') - 2, + ).annotate( + test=Case( + When(f_minus_2=-1, then=Value('negative one')), + When(f_minus_2=0, then=Value('zero')), + When(f_minus_2=1, then=Value('one')), + default=Value('other'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [(1, 'negative one'), (2, 'zero'), (3, 'one'), (2, 'zero'), (3, 'one'), (3, 'one'), (4, 'other')], + transform=attrgetter('integer', 'test') + ) + + def test_annotate_with_aggregation_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + min=Min('fk_rel__integer'), + max=Max('fk_rel__integer'), + ).annotate( + test=Case( + When(integer=2, then='min'), + When(integer=3, then='max'), + ), + ).order_by('pk'), + [(1, None, 1, 1), (2, 2, 2, 3), (3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4), (3, 4, 3, 4), (4, None, 5, 5)], + transform=itemgetter('integer', 'test', 'min', 'max') + ) + + def test_annotate_with_aggregation_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + min=Min('fk_rel__integer'), + max=Max('fk_rel__integer'), + ).annotate( + test=Case( + When(integer2=F('min'), then=Value('min')), + When(integer2=F('max'), then=Value('max')), + output_field=models.CharField(), + ), + ).order_by('pk'), + [(1, 1, 'min'), (2, 3, 'max'), (3, 4, 'max'), (2, 2, 'min'), (3, 4, 'max'), (3, 3, 'min'), (4, 5, 'min')], + transform=itemgetter('integer', 'integer2', 'test') + ) + + def test_annotate_with_aggregation_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + max=Max('fk_rel__integer'), + ).annotate( + test=Case( + When(max=3, then=Value('max = 3')), + When(max=4, then=Value('max = 4')), + default=Value(''), + output_field=models.CharField(), + ), + ).order_by('pk'), + [(1, 1, ''), (2, 3, 'max = 3'), (3, 4, 'max = 4'), (2, 3, 'max = 3'), + (3, 4, 'max = 4'), (3, 4, 'max = 4'), (4, 5, '')], + transform=itemgetter('integer', 'max', 'test') + ) + + def test_combined_expression(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + test=Case( + When(integer=1, then=Value(2)), + When(integer=2, then=Value(1)), + default=Value(3), + output_field=models.IntegerField(), + ) + 1, + ).order_by('pk'), + [(1, 3), (2, 2), (3, 4), (2, 2), (3, 4), (3, 4), (4, 4)], + transform=attrgetter('integer', 'test') + ) + + def test_in_subquery(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter( + pk__in=CaseTestModel.objects.annotate( + test=Case( + When(integer=F('integer2'), then='pk'), + When(integer=4, then='pk'), + output_field=models.IntegerField(), + ), + ).values('test')).order_by('pk'), + [(1, 1), (2, 2), (3, 3), (4, 5)], + transform=attrgetter('integer', 'integer2') + ) + + def test_aggregate(self): + self.assertEqual( + CaseTestModel.objects.aggregate( + one=models.Sum(Case( + When(integer=1, then=Value(1)), + output_field=models.IntegerField(), + )), + two=models.Sum(Case( + When(integer=2, then=Value(1)), + output_field=models.IntegerField(), + )), + three=models.Sum(Case( + When(integer=3, then=Value(1)), + output_field=models.IntegerField(), + )), + four=models.Sum(Case( + When(integer=4, then=Value(1)), + output_field=models.IntegerField(), + )), + ), + {'one': 1, 'two': 2, 'three': 3, 'four': 1} + ) + + def test_aggregate_with_expression_as_value(self): + self.assertEqual( + CaseTestModel.objects.aggregate( + one=models.Sum(Case(When(integer=1, then='integer'))), + two=models.Sum(Case(When(integer=2, then=F('integer') - 1))), + three=models.Sum(Case(When(integer=3, then=F('integer') + 1))), + ), + {'one': 1, 'two': 2, 'three': 12} + ) + + def test_aggregate_with_expression_as_condition(self): + self.assertEqual( + CaseTestModel.objects.aggregate( + equal=models.Sum(Case( + When(integer2=F('integer'), then=Value(1)), + output_field=models.IntegerField(), + )), + plus_one=models.Sum(Case( + When(integer2=F('integer') + 1, then=Value(1)), + output_field=models.IntegerField(), + )), + ), + {'equal': 3, 'plus_one': 4} + ) + + def test_filter(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(integer=2, then=Value(3)), + When(integer=3, then=Value(4)), + default=Value(1), + output_field=models.IntegerField(), + )).order_by('pk'), + [(1, 1), (2, 3), (3, 4), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_without_default(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(integer=2, then=Value(3)), + When(integer=3, then=Value(4)), + output_field=models.IntegerField(), + )).order_by('pk'), + [(2, 3), (3, 4), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_expression_as_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(integer=2, then=F('integer') + 1), + When(integer=3, then=F('integer')), + default='integer', + )).order_by('pk'), + [(1, 1), (2, 3), (3, 3)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_expression_as_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(string=Case( + When(integer2=F('integer'), then=Value('2')), + When(integer2=F('integer') + 1, then=Value('3')), + output_field=models.CharField(), + )).order_by('pk'), + [(3, 4, '3'), (2, 2, '2'), (3, 4, '3')], + transform=attrgetter('integer', 'integer2', 'string') + ) + + def test_filter_with_join_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(integer=2, then=F('o2o_rel__integer') + 1), + When(integer=3, then=F('o2o_rel__integer')), + default='o2o_rel__integer', + )).order_by('pk'), + [(1, 1), (2, 3), (3, 3)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_join_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer=Case( + When(integer2=F('o2o_rel__integer') + 1, then=Value(2)), + When(integer2=F('o2o_rel__integer'), then=Value(3)), + output_field=models.IntegerField(), + )).order_by('pk'), + [(2, 3), (3, 3)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_join_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.filter(integer2=Case( + When(o2o_rel__integer=1, then=Value(1)), + When(o2o_rel__integer=2, then=Value(3)), + When(o2o_rel__integer=3, then=Value(4)), + output_field=models.IntegerField(), + )).order_by('pk'), + [(1, 1), (2, 3), (3, 4), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_annotation_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f=F('integer'), + f_plus_1=F('integer') + 1, + ).filter( + integer2=Case( + When(integer=2, then='f_plus_1'), + When(integer=3, then='f'), + ), + ).order_by('pk'), + [(2, 3), (3, 3)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_annotation_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_plus_1=F('integer') + 1, + ).filter( + integer=Case( + When(integer2=F('integer'), then=Value(2)), + When(integer2=F('f_plus_1'), then=Value(3)), + output_field=models.IntegerField(), + ), + ).order_by('pk'), + [(3, 4), (2, 2), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_annotation_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + f_plus_1=F('integer') + 1, + ).filter( + integer2=Case( + When(f_plus_1=3, then=Value(3)), + When(f_plus_1=4, then=Value(4)), + default=Value(1), + output_field=models.IntegerField(), + ), + ).order_by('pk'), + [(1, 1), (2, 3), (3, 4), (3, 4)], + transform=attrgetter('integer', 'integer2') + ) + + def test_filter_with_aggregation_in_value(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + min=Min('fk_rel__integer'), + max=Max('fk_rel__integer'), + ).filter( + integer2=Case( + When(integer=2, then='min'), + When(integer=3, then='max'), + ), + ).order_by('pk'), + [(3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4)], + transform=itemgetter('integer', 'integer2', 'min', 'max') + ) + + def test_filter_with_aggregation_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + min=Min('fk_rel__integer'), + max=Max('fk_rel__integer'), + ).filter( + integer=Case( + When(integer2=F('min'), then=Value(2)), + When(integer2=F('max'), then=Value(3)), + ), + ).order_by('pk'), + [(3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4)], + transform=itemgetter('integer', 'integer2', 'min', 'max') + ) + + def test_filter_with_aggregation_in_predicate(self): + self.assertQuerysetEqual( + CaseTestModel.objects.values(*self.non_lob_fields).annotate( + max=Max('fk_rel__integer'), + ).filter( + integer=Case( + When(max=3, then=Value(2)), + When(max=4, then=Value(3)), + ), + ).order_by('pk'), + [(2, 3, 3), (3, 4, 4), (2, 2, 3), (3, 4, 4), (3, 3, 4)], + transform=itemgetter('integer', 'integer2', 'max') + ) + + def test_update(self): + CaseTestModel.objects.update( + string=Case( + When(integer=1, then=Value('one')), + When(integer=2, then=Value('two')), + default=Value('other'), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 'one'), (2, 'two'), (3, 'other'), (2, 'two'), (3, 'other'), (3, 'other'), (4, 'other')], + transform=attrgetter('integer', 'string') + ) + + def test_update_without_default(self): + CaseTestModel.objects.update( + integer2=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'integer2') + ) + + def test_update_with_expression_as_value(self): + CaseTestModel.objects.update( + integer=Case( + When(integer=1, then=F('integer') + 1), + When(integer=2, then=F('integer') + 3), + default='integer', + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [('1', 2), ('2', 5), ('3', 3), ('2', 5), ('3', 3), ('3', 3), ('4', 4)], + transform=attrgetter('string', 'integer') + ) + + def test_update_with_expression_as_condition(self): + CaseTestModel.objects.update( + string=Case( + When(integer2=F('integer'), then=Value('equal')), + When(integer2=F('integer') + 1, then=Value('+1')), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')], + transform=attrgetter('integer', 'string') + ) + + def test_update_with_join_in_condition_raise_field_error(self): + with self.assertRaisesMessage(FieldError, 'Joined field references are not permitted in this query'): + CaseTestModel.objects.update( + integer=Case( + When(integer2=F('o2o_rel__integer') + 1, then=Value(2)), + When(integer2=F('o2o_rel__integer'), then=Value(3)), + output_field=models.IntegerField(), + ), + ) + + def test_update_with_join_in_predicate_raise_field_error(self): + with self.assertRaisesMessage(FieldError, 'Joined field references are not permitted in this query'): + CaseTestModel.objects.update( + string=Case( + When(o2o_rel__integer=1, then=Value('one')), + When(o2o_rel__integer=2, then=Value('two')), + When(o2o_rel__integer=3, then=Value('three')), + default=Value('other'), + output_field=models.CharField(), + ), + ) + + def test_update_big_integer(self): + CaseTestModel.objects.update( + big_integer=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'big_integer') + ) + + def test_update_binary(self): + CaseTestModel.objects.update( + binary=Case( + # fails on postgresql on Python 2.7 if output_field is not + # set explicitly + When(integer=1, then=Value(b'one', output_field=models.BinaryField())), + When(integer=2, then=Value(b'two', output_field=models.BinaryField())), + default=Value(b'', output_field=models.BinaryField()), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, b'one'), (2, b'two'), (3, b''), (2, b'two'), (3, b''), (3, b''), (4, b'')], + transform=lambda o: (o.integer, six.binary_type(o.binary)) + ) + + def test_update_boolean(self): + CaseTestModel.objects.update( + boolean=Case( + When(integer=1, then=Value(True)), + When(integer=2, then=Value(True)), + default=Value(False), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, True), (2, True), (3, False), (2, True), (3, False), (3, False), (4, False)], + transform=attrgetter('integer', 'boolean') + ) + + def test_update_comma_separated_integer(self): + CaseTestModel.objects.update( + comma_separated_integer=Case( + When(integer=1, then=Value('1')), + When(integer=2, then=Value('2,2')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1'), (2, '2,2'), (3, ''), (2, '2,2'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'comma_separated_integer') + ) + + def test_update_date(self): + CaseTestModel.objects.update( + date=Case( + When(integer=1, then=Value(date(2015, 1, 1))), + When(integer=2, then=Value(date(2015, 1, 2))), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [ + (1, date(2015, 1, 1)), (2, date(2015, 1, 2)), (3, None), (2, date(2015, 1, 2)), + (3, None), (3, None), (4, None) + ], + transform=attrgetter('integer', 'date') + ) + + def test_update_date_time(self): + CaseTestModel.objects.update( + date_time=Case( + When(integer=1, then=Value(datetime(2015, 1, 1))), + When(integer=2, then=Value(datetime(2015, 1, 2))), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [ + (1, datetime(2015, 1, 1)), (2, datetime(2015, 1, 2)), (3, None), (2, datetime(2015, 1, 2)), + (3, None), (3, None), (4, None) + ], + transform=attrgetter('integer', 'date_time') + ) + + def test_update_decimal(self): + CaseTestModel.objects.update( + decimal=Case( + When(integer=1, then=Value(Decimal('1.1'))), + When(integer=2, then=Value(Decimal('2.2'))), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, Decimal('1.1')), (2, Decimal('2.2')), (3, None), (2, Decimal('2.2')), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'decimal') + ) + + def test_update_duration(self): + CaseTestModel.objects.update( + duration=Case( + # fails on sqlite if output_field is not set explicitly on all + # Values containing timedeltas + When(integer=1, then=Value(timedelta(1), output_field=models.DurationField())), + When(integer=2, then=Value(timedelta(2), output_field=models.DurationField())), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, timedelta(1)), (2, timedelta(2)), (3, None), (2, timedelta(2)), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'duration') + ) + + def test_update_email(self): + CaseTestModel.objects.update( + email=Case( + When(integer=1, then=Value('1@example.com')), + When(integer=2, then=Value('2@example.com')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1@example.com'), (2, '2@example.com'), (3, ''), (2, '2@example.com'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'email') + ) + + def test_update_file(self): + CaseTestModel.objects.update( + file=Case( + When(integer=1, then=Value('~/1')), + When(integer=2, then=Value('~/2')), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], + transform=lambda o: (o.integer, six.text_type(o.file)) + ) + + def test_update_file_path(self): + CaseTestModel.objects.update( + file_path=Case( + When(integer=1, then=Value('~/1')), + When(integer=2, then=Value('~/2')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'file_path') + ) + + def test_update_float(self): + CaseTestModel.objects.update( + float=Case( + When(integer=1, then=Value(1.1)), + When(integer=2, then=Value(2.2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1.1), (2, 2.2), (3, None), (2, 2.2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'float') + ) + + def test_update_image(self): + CaseTestModel.objects.update( + image=Case( + When(integer=1, then=Value('~/1')), + When(integer=2, then=Value('~/2')), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '~/1'), (2, '~/2'), (3, ''), (2, '~/2'), (3, ''), (3, ''), (4, '')], + transform=lambda o: (o.integer, six.text_type(o.image)) + ) + + def test_update_ip_address(self): + CaseTestModel.objects.update( + ip_address=Case( + # fails on postgresql if output_field is not set explicitly + When(integer=1, then=Value('1.1.1.1')), + When(integer=2, then=Value('2.2.2.2')), + output_field=models.IPAddressField(), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1.1.1.1'), (2, '2.2.2.2'), (3, None), (2, '2.2.2.2'), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'ip_address') + ) + + def test_update_generic_ip_address(self): + CaseTestModel.objects.update( + generic_ip_address=Case( + # fails on postgresql if output_field is not set explicitly + When(integer=1, then=Value('1.1.1.1')), + When(integer=2, then=Value('2.2.2.2')), + output_field=models.GenericIPAddressField(), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1.1.1.1'), (2, '2.2.2.2'), (3, None), (2, '2.2.2.2'), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'generic_ip_address') + ) + + def test_update_null_boolean(self): + CaseTestModel.objects.update( + null_boolean=Case( + When(integer=1, then=Value(True)), + When(integer=2, then=Value(False)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, True), (2, False), (3, None), (2, False), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'null_boolean') + ) + + def test_update_positive_integer(self): + CaseTestModel.objects.update( + positive_integer=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'positive_integer') + ) + + def test_update_positive_small_integer(self): + CaseTestModel.objects.update( + positive_small_integer=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'positive_small_integer') + ) + + def test_update_slug(self): + CaseTestModel.objects.update( + slug=Case( + When(integer=1, then=Value('1')), + When(integer=2, then=Value('2')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1'), (2, '2'), (3, ''), (2, '2'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'slug') + ) + + def test_update_small_integer(self): + CaseTestModel.objects.update( + small_integer=Case( + When(integer=1, then=Value(1)), + When(integer=2, then=Value(2)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'small_integer') + ) + + def test_update_string(self): + CaseTestModel.objects.filter(string__in=['1', '2']).update( + string=Case( + When(integer=1, then=Value('1', output_field=models.CharField())), + When(integer=2, then=Value('2', output_field=models.CharField())), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.filter(string__in=['1', '2']).order_by('pk'), + [(1, '1'), (2, '2'), (2, '2')], + transform=attrgetter('integer', 'string') + ) + + def test_update_text(self): + CaseTestModel.objects.update( + text=Case( + When(integer=1, then=Value('1')), + When(integer=2, then=Value('2')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, '1'), (2, '2'), (3, ''), (2, '2'), (3, ''), (3, ''), (4, '')], + transform=attrgetter('integer', 'text') + ) + + def test_update_time(self): + CaseTestModel.objects.update( + time=Case( + # fails on sqlite if output_field is not set explicitly on all + # Values containing times + When(integer=1, then=Value(time(1), output_field=models.TimeField())), + When(integer=2, then=Value(time(2), output_field=models.TimeField())), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, time(1)), (2, time(2)), (3, None), (2, time(2)), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'time') + ) + + def test_update_url(self): + CaseTestModel.objects.update( + url=Case( + When(integer=1, then=Value('http://1.example.com/')), + When(integer=2, then=Value('http://2.example.com/')), + default=Value(''), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [ + (1, 'http://1.example.com/'), (2, 'http://2.example.com/'), (3, ''), (2, 'http://2.example.com/'), + (3, ''), (3, ''), (4, '') + ], + transform=attrgetter('integer', 'url') + ) + + def test_update_uuid(self): + CaseTestModel.objects.update( + uuid=Case( + # fails on sqlite if output_field is not set explicitly on all + # Values containing UUIDs + When(integer=1, then=Value( + UUID('11111111111111111111111111111111'), + output_field=models.UUIDField(), + )), + When(integer=2, then=Value( + UUID('22222222222222222222222222222222'), + output_field=models.UUIDField(), + )), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [ + (1, UUID('11111111111111111111111111111111')), (2, UUID('22222222222222222222222222222222')), (3, None), + (2, UUID('22222222222222222222222222222222')), (3, None), (3, None), (4, None) + ], + transform=attrgetter('integer', 'uuid') + ) + + def test_update_fk(self): + obj1, obj2 = CaseTestModel.objects.all()[:2] + + CaseTestModel.objects.update( + fk=Case( + When(integer=1, then=Value(obj1.pk)), + When(integer=2, then=Value(obj2.pk)), + ), + ) + self.assertQuerysetEqual( + CaseTestModel.objects.all().order_by('pk'), + [(1, obj1.pk), (2, obj2.pk), (3, None), (2, obj2.pk), (3, None), (3, None), (4, None)], + transform=attrgetter('integer', 'fk_id') + ) + + def test_lookup_in_condition(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + test=Case( + When(integer__lt=2, then=Value('less than 2')), + When(integer__gt=2, then=Value('greater than 2')), + default=Value('equal to 2'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [ + (1, 'less than 2'), (2, 'equal to 2'), (3, 'greater than 2'), (2, 'equal to 2'), (3, 'greater than 2'), + (3, 'greater than 2'), (4, 'greater than 2') + ], + transform=attrgetter('integer', 'test') + ) + + def test_lookup_different_fields(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + test=Case( + When(integer=2, integer2=3, then=Value('when')), + default=Value('default'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [ + (1, 1, 'default'), (2, 3, 'when'), (3, 4, 'default'), (2, 2, 'default'), (3, 4, 'default'), + (3, 3, 'default'), (4, 5, 'default') + ], + transform=attrgetter('integer', 'integer2', 'test') + ) + + def test_combined_q_object(self): + self.assertQuerysetEqual( + CaseTestModel.objects.annotate( + test=Case( + When(Q(integer=2) | Q(integer2=3), then=Value('when')), + default=Value('default'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [ + (1, 1, 'default'), (2, 3, 'when'), (3, 4, 'default'), (2, 2, 'when'), (3, 4, 'default'), + (3, 3, 'when'), (4, 5, 'default') + ], + transform=attrgetter('integer', 'integer2', 'test') + ) + + +class CaseDocumentationExamples(TestCase): + @classmethod + def setUpTestData(cls): + Client.objects.create( + name='Jane Doe', + account_type=Client.REGULAR, + registered_on=date.today() - timedelta(days=36), + ) + Client.objects.create( + name='James Smith', + account_type=Client.GOLD, + registered_on=date.today() - timedelta(days=5), + ) + Client.objects.create( + name='Jack Black', + account_type=Client.PLATINUM, + registered_on=date.today() - timedelta(days=10 * 365), + ) + + def test_simple_example(self): + self.assertQuerysetEqual( + Client.objects.annotate( + discount=Case( + When(account_type=Client.GOLD, then=Value('5%')), + When(account_type=Client.PLATINUM, then=Value('10%')), + default=Value('0%'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [('Jane Doe', '0%'), ('James Smith', '5%'), ('Jack Black', '10%')], + transform=attrgetter('name', 'discount') + ) + + def test_lookup_example(self): + a_month_ago = date.today() - timedelta(days=30) + a_year_ago = date.today() - timedelta(days=365) + self.assertQuerysetEqual( + Client.objects.annotate( + discount=Case( + When(registered_on__lte=a_year_ago, then=Value('10%')), + When(registered_on__lte=a_month_ago, then=Value('5%')), + default=Value('0%'), + output_field=models.CharField(), + ), + ).order_by('pk'), + [('Jane Doe', '5%'), ('James Smith', '0%'), ('Jack Black', '10%')], + transform=attrgetter('name', 'discount') + ) + + def test_conditional_update_example(self): + a_month_ago = date.today() - timedelta(days=30) + a_year_ago = date.today() - timedelta(days=365) + Client.objects.update( + account_type=Case( + When(registered_on__lte=a_year_ago, then=Value(Client.PLATINUM)), + When(registered_on__lte=a_month_ago, then=Value(Client.GOLD)), + default=Value(Client.REGULAR), + ), + ) + self.assertQuerysetEqual( + Client.objects.all().order_by('pk'), + [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')], + transform=attrgetter('name', 'account_type') + ) + + def test_conditional_aggregation_example(self): + Client.objects.create( + name='Jean Grey', + account_type=Client.REGULAR, + registered_on=date.today(), + ) + Client.objects.create( + name='James Bond', + account_type=Client.PLATINUM, + registered_on=date.today(), + ) + Client.objects.create( + name='Jane Porter', + account_type=Client.PLATINUM, + registered_on=date.today(), + ) + self.assertEqual( + Client.objects.aggregate( + regular=models.Sum(Case( + When(account_type=Client.REGULAR, then=Value(1)), + output_field=models.IntegerField(), + )), + gold=models.Sum(Case( + When(account_type=Client.GOLD, then=Value(1)), + output_field=models.IntegerField(), + )), + platinum=models.Sum(Case( + When(account_type=Client.PLATINUM, then=Value(1)), + output_field=models.IntegerField(), + )), + ), + {'regular': 2, 'gold': 1, 'platinum': 3} + ) |