diff options
author | Bhuvnesh <bhuvnesh875@gmail.com> | 2022-11-04 00:41:19 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-03 20:11:19 +0100 |
commit | e580b891cb5ae31eb0571c88428afb9bf69e47f2 (patch) | |
tree | 6edae59e3e5c93a40355863e0e88bac98dae8a43 /tests/async | |
parent | 3dc9f3ac6960c83cd32058677eb0ddb5a5e5da43 (diff) | |
download | django-e580b891cb5ae31eb0571c88428afb9bf69e47f2.tar.gz |
Refs #33646 -- Moved tests of QuerySet async interface into async tests.
Diffstat (limited to 'tests/async')
-rw-r--r-- | tests/async/models.py | 6 | ||||
-rw-r--r-- | tests/async/test_async_queryset.py | 248 |
2 files changed, 254 insertions, 0 deletions
diff --git a/tests/async/models.py b/tests/async/models.py index 0fd606b07e..8cb051258c 100644 --- a/tests/async/models.py +++ b/tests/async/models.py @@ -1,5 +1,11 @@ from django.db import models +from django.utils import timezone + + +class RelatedModel(models.Model): + simple = models.ForeignKey("SimpleModel", models.CASCADE, null=True) class SimpleModel(models.Model): field = models.IntegerField() + created = models.DateTimeField(default=timezone.now) diff --git a/tests/async/test_async_queryset.py b/tests/async/test_async_queryset.py new file mode 100644 index 0000000000..253183ea10 --- /dev/null +++ b/tests/async/test_async_queryset.py @@ -0,0 +1,248 @@ +import json +import xml.etree.ElementTree +from datetime import datetime + +from asgiref.sync import async_to_sync, sync_to_async + +from django.db import NotSupportedError, connection +from django.db.models import Sum +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature + +from .models import SimpleModel + + +class AsyncQuerySetTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.s1 = SimpleModel.objects.create( + field=1, + created=datetime(2022, 1, 1, 0, 0, 0), + ) + cls.s2 = SimpleModel.objects.create( + field=2, + created=datetime(2022, 1, 1, 0, 0, 1), + ) + cls.s3 = SimpleModel.objects.create( + field=3, + created=datetime(2022, 1, 1, 0, 0, 2), + ) + + @staticmethod + def _get_db_feature(connection_, feature_name): + # Wrapper to avoid accessing connection attributes until inside + # coroutine function. Connection access is thread sensitive and cannot + # be passed across sync/async boundaries. + return getattr(connection_.features, feature_name) + + async def test_async_iteration(self): + results = [] + async for m in SimpleModel.objects.order_by("pk"): + results.append(m) + self.assertEqual(results, [self.s1, self.s2, self.s3]) + + async def test_aiterator(self): + qs = SimpleModel.objects.aiterator() + results = [] + async for m in qs: + results.append(m) + self.assertCountEqual(results, [self.s1, self.s2, self.s3]) + + async def test_aiterator_prefetch_related(self): + qs = SimpleModel.objects.prefetch_related("relatedmodels").aiterator() + msg = "Using QuerySet.aiterator() after prefetch_related() is not supported." + with self.assertRaisesMessage(NotSupportedError, msg): + async for m in qs: + pass + + async def test_aiterator_invalid_chunk_size(self): + msg = "Chunk size must be strictly positive." + for size in [0, -1]: + qs = SimpleModel.objects.aiterator(chunk_size=size) + with self.subTest(size=size), self.assertRaisesMessage(ValueError, msg): + async for m in qs: + pass + + async def test_acount(self): + count = await SimpleModel.objects.acount() + self.assertEqual(count, 3) + + async def test_acount_cached_result(self): + qs = SimpleModel.objects.all() + # Evaluate the queryset to populate the query cache. + [x async for x in qs] + count = await qs.acount() + self.assertEqual(count, 3) + + await sync_to_async(SimpleModel.objects.create)( + field=4, + created=datetime(2022, 1, 1, 0, 0, 0), + ) + # The query cache is used. + count = await qs.acount() + self.assertEqual(count, 3) + + async def test_aget(self): + instance = await SimpleModel.objects.aget(field=1) + self.assertEqual(instance, self.s1) + + async def test_acreate(self): + await SimpleModel.objects.acreate(field=4) + self.assertEqual(await SimpleModel.objects.acount(), 4) + + async def test_aget_or_create(self): + instance, created = await SimpleModel.objects.aget_or_create(field=4) + self.assertEqual(await SimpleModel.objects.acount(), 4) + self.assertIs(created, True) + + async def test_aupdate_or_create(self): + instance, created = await SimpleModel.objects.aupdate_or_create( + id=self.s1.id, defaults={"field": 2} + ) + self.assertEqual(instance, self.s1) + self.assertIs(created, False) + instance, created = await SimpleModel.objects.aupdate_or_create(field=4) + self.assertEqual(await SimpleModel.objects.acount(), 4) + self.assertIs(created, True) + + @skipUnlessDBFeature("has_bulk_insert") + @async_to_sync + async def test_abulk_create(self): + instances = [SimpleModel(field=i) for i in range(10)] + qs = await SimpleModel.objects.abulk_create(instances) + self.assertEqual(len(qs), 10) + + @skipUnlessDBFeature("has_bulk_insert", "supports_update_conflicts") + @skipIfDBFeature("supports_update_conflicts_with_target") + @async_to_sync + async def test_update_conflicts_unique_field_unsupported(self): + msg = ( + "This database backend does not support updating conflicts with specifying " + "unique fields that can trigger the upsert." + ) + with self.assertRaisesMessage(NotSupportedError, msg): + await SimpleModel.objects.abulk_create( + [SimpleModel(field=1), SimpleModel(field=2)], + update_conflicts=True, + update_fields=["field"], + unique_fields=["created"], + ) + + async def test_abulk_update(self): + instances = SimpleModel.objects.all() + async for instance in instances: + instance.field = instance.field * 10 + + await SimpleModel.objects.abulk_update(instances, ["field"]) + + qs = [(o.pk, o.field) async for o in SimpleModel.objects.all()] + self.assertCountEqual( + qs, + [(self.s1.pk, 10), (self.s2.pk, 20), (self.s3.pk, 30)], + ) + + async def test_ain_bulk(self): + res = await SimpleModel.objects.ain_bulk() + self.assertEqual( + res, + {self.s1.pk: self.s1, self.s2.pk: self.s2, self.s3.pk: self.s3}, + ) + + res = await SimpleModel.objects.ain_bulk([self.s2.pk]) + self.assertEqual(res, {self.s2.pk: self.s2}) + + res = await SimpleModel.objects.ain_bulk([self.s2.pk], field_name="id") + self.assertEqual(res, {self.s2.pk: self.s2}) + + async def test_alatest(self): + instance = await SimpleModel.objects.alatest("created") + self.assertEqual(instance, self.s3) + + instance = await SimpleModel.objects.alatest("-created") + self.assertEqual(instance, self.s1) + + async def test_aearliest(self): + instance = await SimpleModel.objects.aearliest("created") + self.assertEqual(instance, self.s1) + + instance = await SimpleModel.objects.aearliest("-created") + self.assertEqual(instance, self.s3) + + async def test_afirst(self): + instance = await SimpleModel.objects.afirst() + self.assertEqual(instance, self.s1) + + instance = await SimpleModel.objects.filter(field=4).afirst() + self.assertIsNone(instance) + + async def test_alast(self): + instance = await SimpleModel.objects.alast() + self.assertEqual(instance, self.s3) + + instance = await SimpleModel.objects.filter(field=4).alast() + self.assertIsNone(instance) + + async def test_aaggregate(self): + total = await SimpleModel.objects.aaggregate(total=Sum("field")) + self.assertEqual(total, {"total": 6}) + + async def test_aexists(self): + check = await SimpleModel.objects.filter(field=1).aexists() + self.assertIs(check, True) + + check = await SimpleModel.objects.filter(field=4).aexists() + self.assertIs(check, False) + + async def test_acontains(self): + check = await SimpleModel.objects.acontains(self.s1) + self.assertIs(check, True) + # Unsaved instances are not allowed, so use an ID known not to exist. + check = await SimpleModel.objects.acontains( + SimpleModel(id=self.s3.id + 1, field=4) + ) + self.assertIs(check, False) + + async def test_aupdate(self): + await SimpleModel.objects.aupdate(field=99) + qs = [o async for o in SimpleModel.objects.all()] + values = [instance.field for instance in qs] + self.assertEqual(set(values), {99}) + + async def test_adelete(self): + await SimpleModel.objects.filter(field=2).adelete() + qs = [o async for o in SimpleModel.objects.all()] + self.assertCountEqual(qs, [self.s1, self.s3]) + + @skipUnlessDBFeature("supports_explaining_query_execution") + @async_to_sync + async def test_aexplain(self): + supported_formats = await sync_to_async(self._get_db_feature)( + connection, "supported_explain_formats" + ) + all_formats = (None, *supported_formats) + for format_ in all_formats: + with self.subTest(format=format_): + # TODO: Check the captured query when async versions of + # self.assertNumQueries/CaptureQueriesContext context + # processors are available. + result = await SimpleModel.objects.filter(field=1).aexplain( + format=format_ + ) + self.assertIsInstance(result, str) + self.assertTrue(result) + if not format_: + continue + if format_.lower() == "xml": + try: + xml.etree.ElementTree.fromstring(result) + except xml.etree.ElementTree.ParseError as e: + self.fail(f"QuerySet.aexplain() result is not valid XML: {e}") + elif format_.lower() == "json": + try: + json.loads(result) + except json.JSONDecodeError as e: + self.fail(f"QuerySet.aexplain() result is not valid JSON: {e}") + + async def test_raw(self): + sql = "SELECT id, field FROM async_simplemodel WHERE created=%s" + qs = SimpleModel.objects.raw(sql, [self.s1.created]) + self.assertEqual([o async for o in qs], [self.s1]) |