diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/expressions/tests.py | 56 | ||||
-rw-r--r-- | tests/queries/test_q.py | 4 | ||||
-rw-r--r-- | tests/update/models.py | 1 | ||||
-rw-r--r-- | tests/update/tests.py | 30 |
4 files changed, 87 insertions, 4 deletions
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index fd0094db63..465edc54b5 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -48,6 +48,7 @@ from django.db.models.expressions import ( Col, Combinable, CombinedExpression, + NegatedExpression, RawSQL, Ref, ) @@ -2536,6 +2537,61 @@ class ExpressionWrapperTests(SimpleTestCase): self.assertEqual(group_by_cols[0].output_field, expr.output_field) +class NegatedExpressionTests(TestCase): + @classmethod + def setUpTestData(cls): + ceo = Employee.objects.create(firstname="Joe", lastname="Smith", salary=10) + cls.eu_company = Company.objects.create( + name="Example Inc.", + num_employees=2300, + num_chairs=5, + ceo=ceo, + based_in_eu=True, + ) + cls.non_eu_company = Company.objects.create( + name="Foobar Ltd.", + num_employees=3, + num_chairs=4, + ceo=ceo, + based_in_eu=False, + ) + + def test_invert(self): + f = F("field") + self.assertEqual(~f, NegatedExpression(f)) + self.assertIsNot(~~f, f) + self.assertEqual(~~f, f) + + def test_filter(self): + self.assertSequenceEqual( + Company.objects.filter(~F("based_in_eu")), + [self.non_eu_company], + ) + + qs = Company.objects.annotate(eu_required=~Value(False)) + self.assertSequenceEqual( + qs.filter(based_in_eu=F("eu_required")).order_by("eu_required"), + [self.eu_company], + ) + self.assertSequenceEqual( + qs.filter(based_in_eu=~~F("eu_required")), + [self.eu_company], + ) + self.assertSequenceEqual( + qs.filter(based_in_eu=~F("eu_required")), + [self.non_eu_company], + ) + self.assertSequenceEqual(qs.filter(based_in_eu=~F("based_in_eu")), []) + + def test_values(self): + self.assertSequenceEqual( + Company.objects.annotate(negated=~F("based_in_eu")) + .values_list("name", "negated") + .order_by("name"), + [("Example Inc.", False), ("Foobar Ltd.", True)], + ) + + class OrderByTests(SimpleTestCase): def test_equal(self): self.assertEqual( diff --git a/tests/queries/test_q.py b/tests/queries/test_q.py index 923846b5a3..cdf40292b0 100644 --- a/tests/queries/test_q.py +++ b/tests/queries/test_q.py @@ -8,7 +8,7 @@ from django.db.models import ( Q, Value, ) -from django.db.models.expressions import RawSQL +from django.db.models.expressions import NegatedExpression, RawSQL from django.db.models.functions import Lower from django.db.models.sql.where import NothingNode from django.test import SimpleTestCase, TestCase @@ -87,7 +87,7 @@ class QTests(SimpleTestCase): ] for q in tests: with self.subTest(q=q): - self.assertIs(q.negated, True) + self.assertIsInstance(q, NegatedExpression) def test_deconstruct(self): q = Q(price__gt=F("discounted_price")) diff --git a/tests/update/models.py b/tests/update/models.py index d7452dc302..d71fc887c7 100644 --- a/tests/update/models.py +++ b/tests/update/models.py @@ -10,6 +10,7 @@ class DataPoint(models.Model): name = models.CharField(max_length=20) value = models.CharField(max_length=20) another_value = models.CharField(max_length=20, blank=True) + is_active = models.BooleanField(default=True) class RelatedPoint(models.Model): diff --git a/tests/update/tests.py b/tests/update/tests.py index 2162f5164d..e88eeda96d 100644 --- a/tests/update/tests.py +++ b/tests/update/tests.py @@ -2,7 +2,7 @@ import unittest from django.core.exceptions import FieldError from django.db import IntegrityError, connection, transaction -from django.db.models import CharField, Count, F, IntegerField, Max +from django.db.models import Case, CharField, Count, F, IntegerField, Max, When from django.db.models.functions import Abs, Concat, Lower from django.test import TestCase from django.test.utils import register_lookup @@ -81,7 +81,7 @@ class AdvancedTests(TestCase): def setUpTestData(cls): cls.d0 = DataPoint.objects.create(name="d0", value="apple") cls.d2 = DataPoint.objects.create(name="d2", value="banana") - cls.d3 = DataPoint.objects.create(name="d3", value="banana") + cls.d3 = DataPoint.objects.create(name="d3", value="banana", is_active=False) cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3) def test_update(self): @@ -249,6 +249,32 @@ class AdvancedTests(TestCase): Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3) self.assertEqual(Bar.objects.get().x, 3) + def test_update_negated_f(self): + DataPoint.objects.update(is_active=~F("is_active")) + self.assertCountEqual( + DataPoint.objects.values_list("name", "is_active"), + [("d0", False), ("d2", False), ("d3", True)], + ) + DataPoint.objects.update(is_active=~F("is_active")) + self.assertCountEqual( + DataPoint.objects.values_list("name", "is_active"), + [("d0", True), ("d2", True), ("d3", False)], + ) + + def test_update_negated_f_conditional_annotation(self): + DataPoint.objects.annotate( + is_d2=Case(When(name="d2", then=True), default=False) + ).update(is_active=~F("is_d2")) + self.assertCountEqual( + DataPoint.objects.values_list("name", "is_active"), + [("d0", True), ("d2", False), ("d3", True)], + ) + + def test_updating_non_conditional_field(self): + msg = "Cannot negate non-conditional expressions." + with self.assertRaisesMessage(TypeError, msg): + DataPoint.objects.update(is_active=~F("name")) + @unittest.skipUnless( connection.vendor == "mysql", |