summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/expressions/tests.py56
-rw-r--r--tests/queries/test_q.py4
-rw-r--r--tests/update/models.py1
-rw-r--r--tests/update/tests.py30
4 files changed, 87 insertions, 4 deletions
diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py
index fd0094db63..465edc54b5 100644
--- a/tests/expressions/tests.py
+++ b/tests/expressions/tests.py
@@ -48,6 +48,7 @@ from django.db.models.expressions import (
Col,
Combinable,
CombinedExpression,
+ NegatedExpression,
RawSQL,
Ref,
)
@@ -2536,6 +2537,61 @@ class ExpressionWrapperTests(SimpleTestCase):
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
+class NegatedExpressionTests(TestCase):
+ @classmethod
+ def setUpTestData(cls):
+ ceo = Employee.objects.create(firstname="Joe", lastname="Smith", salary=10)
+ cls.eu_company = Company.objects.create(
+ name="Example Inc.",
+ num_employees=2300,
+ num_chairs=5,
+ ceo=ceo,
+ based_in_eu=True,
+ )
+ cls.non_eu_company = Company.objects.create(
+ name="Foobar Ltd.",
+ num_employees=3,
+ num_chairs=4,
+ ceo=ceo,
+ based_in_eu=False,
+ )
+
+ def test_invert(self):
+ f = F("field")
+ self.assertEqual(~f, NegatedExpression(f))
+ self.assertIsNot(~~f, f)
+ self.assertEqual(~~f, f)
+
+ def test_filter(self):
+ self.assertSequenceEqual(
+ Company.objects.filter(~F("based_in_eu")),
+ [self.non_eu_company],
+ )
+
+ qs = Company.objects.annotate(eu_required=~Value(False))
+ self.assertSequenceEqual(
+ qs.filter(based_in_eu=F("eu_required")).order_by("eu_required"),
+ [self.eu_company],
+ )
+ self.assertSequenceEqual(
+ qs.filter(based_in_eu=~~F("eu_required")),
+ [self.eu_company],
+ )
+ self.assertSequenceEqual(
+ qs.filter(based_in_eu=~F("eu_required")),
+ [self.non_eu_company],
+ )
+ self.assertSequenceEqual(qs.filter(based_in_eu=~F("based_in_eu")), [])
+
+ def test_values(self):
+ self.assertSequenceEqual(
+ Company.objects.annotate(negated=~F("based_in_eu"))
+ .values_list("name", "negated")
+ .order_by("name"),
+ [("Example Inc.", False), ("Foobar Ltd.", True)],
+ )
+
+
class OrderByTests(SimpleTestCase):
def test_equal(self):
self.assertEqual(
diff --git a/tests/queries/test_q.py b/tests/queries/test_q.py
index 923846b5a3..cdf40292b0 100644
--- a/tests/queries/test_q.py
+++ b/tests/queries/test_q.py
@@ -8,7 +8,7 @@ from django.db.models import (
Q,
Value,
)
-from django.db.models.expressions import RawSQL
+from django.db.models.expressions import NegatedExpression, RawSQL
from django.db.models.functions import Lower
from django.db.models.sql.where import NothingNode
from django.test import SimpleTestCase, TestCase
@@ -87,7 +87,7 @@ class QTests(SimpleTestCase):
]
for q in tests:
with self.subTest(q=q):
- self.assertIs(q.negated, True)
+ self.assertIsInstance(q, NegatedExpression)
def test_deconstruct(self):
q = Q(price__gt=F("discounted_price"))
diff --git a/tests/update/models.py b/tests/update/models.py
index d7452dc302..d71fc887c7 100644
--- a/tests/update/models.py
+++ b/tests/update/models.py
@@ -10,6 +10,7 @@ class DataPoint(models.Model):
name = models.CharField(max_length=20)
value = models.CharField(max_length=20)
another_value = models.CharField(max_length=20, blank=True)
+ is_active = models.BooleanField(default=True)
class RelatedPoint(models.Model):
diff --git a/tests/update/tests.py b/tests/update/tests.py
index 2162f5164d..e88eeda96d 100644
--- a/tests/update/tests.py
+++ b/tests/update/tests.py
@@ -2,7 +2,7 @@ import unittest
from django.core.exceptions import FieldError
from django.db import IntegrityError, connection, transaction
-from django.db.models import CharField, Count, F, IntegerField, Max
+from django.db.models import Case, CharField, Count, F, IntegerField, Max, When
from django.db.models.functions import Abs, Concat, Lower
from django.test import TestCase
from django.test.utils import register_lookup
@@ -81,7 +81,7 @@ class AdvancedTests(TestCase):
def setUpTestData(cls):
cls.d0 = DataPoint.objects.create(name="d0", value="apple")
cls.d2 = DataPoint.objects.create(name="d2", value="banana")
- cls.d3 = DataPoint.objects.create(name="d3", value="banana")
+ cls.d3 = DataPoint.objects.create(name="d3", value="banana", is_active=False)
cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3)
def test_update(self):
@@ -249,6 +249,32 @@ class AdvancedTests(TestCase):
Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
self.assertEqual(Bar.objects.get().x, 3)
+ def test_update_negated_f(self):
+ DataPoint.objects.update(is_active=~F("is_active"))
+ self.assertCountEqual(
+ DataPoint.objects.values_list("name", "is_active"),
+ [("d0", False), ("d2", False), ("d3", True)],
+ )
+ DataPoint.objects.update(is_active=~F("is_active"))
+ self.assertCountEqual(
+ DataPoint.objects.values_list("name", "is_active"),
+ [("d0", True), ("d2", True), ("d3", False)],
+ )
+
+ def test_update_negated_f_conditional_annotation(self):
+ DataPoint.objects.annotate(
+ is_d2=Case(When(name="d2", then=True), default=False)
+ ).update(is_active=~F("is_d2"))
+ self.assertCountEqual(
+ DataPoint.objects.values_list("name", "is_active"),
+ [("d0", True), ("d2", False), ("d3", True)],
+ )
+
+ def test_updating_non_conditional_field(self):
+ msg = "Cannot negate non-conditional expressions."
+ with self.assertRaisesMessage(TypeError, msg):
+ DataPoint.objects.update(is_active=~F("name"))
+
@unittest.skipUnless(
connection.vendor == "mysql",