summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Wobrock <david.wobrock@gmail.com>2019-12-28 22:42:46 +0100
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2019-12-31 10:35:43 +0100
commit2f565f84aca136d9cc4e4d061f3196ddf9358ab8 (patch)
treeb084853444500fa3eeaa8a1fe07df961015f46e4
parent7d44aeb388a33ac56353a658b41a6119e93931ff (diff)
downloaddjango-2f565f84aca136d9cc4e4d061f3196ddf9358ab8.tar.gz
Fixed #31097 -- Fixed crash of ArrayAgg and StringAgg with filter when used in Subquery.
-rw-r--r--AUTHORS1
-rw-r--r--django/contrib/postgres/aggregates/mixins.py11
-rw-r--r--tests/postgres_tests/test_aggregates.py46
3 files changed, 47 insertions, 11 deletions
diff --git a/AUTHORS b/AUTHORS
index 5b7d67d06c..4632c66a62 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -242,6 +242,7 @@ answer newbie questions, and generally made Django that much better:
David Sanders <dsanders11@ucsbalum.com>
David Schein
David Tulig <david.tulig@gmail.com>
+ David Wobrock <david.wobrock@gmail.com>
Davide Ceretti <dav.ceretti@gmail.com>
Deepak Thukral <deep.thukral@gmail.com>
Denis Kuzmichyov <kuzmichyov@gmail.com>
diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py
index 4625738beb..3a43ca1a63 100644
--- a/django/contrib/postgres/aggregates/mixins.py
+++ b/django/contrib/postgres/aggregates/mixins.py
@@ -40,16 +40,7 @@ class OrderableAggMixin:
return super().set_source_expressions(exprs[:self._get_ordering_expressions_index()])
def get_source_expressions(self):
- return self.source_expressions + self.ordering
-
- def get_source_fields(self):
- # Filter out fields contributed by the ordering expressions as
- # these should not be used to determine which the return type of the
- # expression.
- return [
- e._output_field_or_none
- for e in self.get_source_expressions()[:self._get_ordering_expressions_index()]
- ]
+ return super().get_source_expressions() + self.ordering
def _get_ordering_expressions_index(self):
"""Return the index at which the ordering expressions start."""
diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py
index 9a042388bd..af84f12e91 100644
--- a/tests/postgres_tests/test_aggregates.py
+++ b/tests/postgres_tests/test_aggregates.py
@@ -21,7 +21,7 @@ except ImportError:
class TestGeneralAggregate(PostgreSQLTestCase):
@classmethod
def setUpTestData(cls):
- AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0)
+ cls.agg1 = AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0)
AggregateTestModel.objects.create(boolean_field=False, char_field='Foo2', integer_field=1)
AggregateTestModel.objects.create(boolean_field=False, char_field='Foo4', integer_field=2)
AggregateTestModel.objects.create(boolean_field=True, char_field='Foo3', integer_field=0)
@@ -249,6 +249,50 @@ class TestGeneralAggregate(PostgreSQLTestCase):
).order_by('char_field').values_list('char_field', 'agg')
self.assertEqual(list(values), expected_result)
+ def test_string_agg_array_agg_filter_in_subquery(self):
+ StatTestModel.objects.bulk_create([
+ StatTestModel(related_field=self.agg1, int1=0, int2=5),
+ StatTestModel(related_field=self.agg1, int1=1, int2=4),
+ StatTestModel(related_field=self.agg1, int1=2, int2=3),
+ ])
+ for aggregate, expected_result in (
+ (
+ ArrayAgg('stattestmodel__int1', filter=Q(stattestmodel__int2__gt=3)),
+ [('Foo1', [0, 1]), ('Foo2', None)],
+ ),
+ (
+ StringAgg(
+ Cast('stattestmodel__int2', CharField()),
+ delimiter=';',
+ filter=Q(stattestmodel__int1__lt=2),
+ ),
+ [('Foo1', '5;4'), ('Foo2', None)],
+ ),
+ ):
+ with self.subTest(aggregate=aggregate.__class__.__name__):
+ subquery = AggregateTestModel.objects.filter(
+ pk=OuterRef('pk'),
+ ).annotate(agg=aggregate).values('agg')
+ values = AggregateTestModel.objects.annotate(
+ agg=Subquery(subquery),
+ ).filter(
+ char_field__in=['Foo1', 'Foo2'],
+ ).order_by('char_field').values_list('char_field', 'agg')
+ self.assertEqual(list(values), expected_result)
+
+ def test_string_agg_filter_in_subquery_with_exclude(self):
+ subquery = AggregateTestModel.objects.annotate(
+ stringagg=StringAgg(
+ 'char_field',
+ delimiter=';',
+ filter=Q(char_field__endswith='1'),
+ )
+ ).exclude(stringagg='').values('id')
+ self.assertSequenceEqual(
+ AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
+ [self.agg1],
+ )
+
class TestAggregateDistinct(PostgreSQLTestCase):
@classmethod