summaryrefslogtreecommitdiff
path: root/tests/expressions_window
diff options
context:
space:
mode:
authorSimon Charette <charette.s@gmail.com>2022-08-10 08:22:01 -0400
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2022-08-15 08:26:26 +0200
commitf387d024fc75569d2a4a338bfda76cc2f328f627 (patch)
tree61994be69d4dfa545158f7e887f0673ec281e479 /tests/expressions_window
parentf3f9d03edf17ccfa17263c7efa0b1350d1ac9278 (diff)
downloaddjango-f387d024fc75569d2a4a338bfda76cc2f328f627.tar.gz
Refs #28333 -- Added partial support for filtering against window functions.
Adds support for joint predicates against window annotations through subquery wrapping while maintaining errors for disjointed filter attempts. The "qualify" wording was used to refer to predicates against window annotations as it's the name of a specialized Snowflake extension to SQL that is to window functions what HAVING is to aggregates. While not complete the implementation should cover most of the common use cases for filtering against window functions without requiring the complex subquery pushdown and predicate re-aliasing machinery to deal with disjointed predicates against columns, aggregates, and window functions. A complete disjointed filtering implementation should likely be deferred until proper QUALIFY support lands or the ORM gains a proper subquery pushdown interface.
Diffstat (limited to 'tests/expressions_window')
-rw-r--r--tests/expressions_window/models.py7
-rw-r--r--tests/expressions_window/tests.py310
2 files changed, 287 insertions, 30 deletions
diff --git a/tests/expressions_window/models.py b/tests/expressions_window/models.py
index 631e876e15..cf324ea8f6 100644
--- a/tests/expressions_window/models.py
+++ b/tests/expressions_window/models.py
@@ -17,6 +17,13 @@ class Employee(models.Model):
bonus = models.DecimalField(decimal_places=2, max_digits=15, null=True)
+class PastEmployeeDepartment(models.Model):
+ employee = models.ForeignKey(
+ Employee, related_name="past_departments", on_delete=models.CASCADE
+ )
+ department = models.CharField(max_length=40, blank=False, null=False)
+
+
class Detail(models.Model):
value = models.JSONField()
diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py
index 15f8a4d6b2..a71a3f947d 100644
--- a/tests/expressions_window/tests.py
+++ b/tests/expressions_window/tests.py
@@ -6,10 +6,9 @@ from django.core.exceptions import FieldError
from django.db import NotSupportedError, connection
from django.db.models import (
Avg,
- BooleanField,
Case,
+ Count,
F,
- Func,
IntegerField,
Max,
Min,
@@ -41,15 +40,17 @@ from django.db.models.functions import (
RowNumber,
Upper,
)
+from django.db.models.lookups import Exact
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
-from .models import Detail, Employee
+from .models import Classification, Detail, Employee, PastEmployeeDepartment
@skipUnlessDBFeature("supports_over_clause")
class WindowFunctionTests(TestCase):
@classmethod
def setUpTestData(cls):
+ classification = Classification.objects.create()
Employee.objects.bulk_create(
[
Employee(
@@ -59,6 +60,7 @@ class WindowFunctionTests(TestCase):
hire_date=e[3],
age=e[4],
bonus=Decimal(e[1]) / 400,
+ classification=classification,
)
for e in [
("Jones", 45000, "Accounting", datetime.datetime(2005, 11, 1), 20),
@@ -82,6 +84,13 @@ class WindowFunctionTests(TestCase):
]
]
)
+ employees = list(Employee.objects.order_by("pk"))
+ PastEmployeeDepartment.objects.bulk_create(
+ [
+ PastEmployeeDepartment(employee=employees[6], department="Sales"),
+ PastEmployeeDepartment(employee=employees[10], department="IT"),
+ ]
+ )
def test_dense_rank(self):
tests = [
@@ -902,6 +911,263 @@ class WindowFunctionTests(TestCase):
)
self.assertEqual(qs.count(), 12)
+ def test_filter(self):
+ qs = Employee.objects.annotate(
+ department_salary_rank=Window(
+ Rank(), partition_by="department", order_by="-salary"
+ ),
+ department_avg_age_diff=(
+ Window(Avg("age"), partition_by="department") - F("age")
+ ),
+ ).order_by("department", "name")
+ # Direct window reference.
+ self.assertQuerysetEqual(
+ qs.filter(department_salary_rank=1),
+ ["Adams", "Wilkinson", "Miller", "Johnson", "Smith"],
+ lambda employee: employee.name,
+ )
+ # Through a combined expression containing a window.
+ self.assertQuerysetEqual(
+ qs.filter(department_avg_age_diff__gt=0),
+ ["Jenson", "Jones", "Williams", "Miller", "Smith"],
+ lambda employee: employee.name,
+ )
+ # Intersection of multiple windows.
+ self.assertQuerysetEqual(
+ qs.filter(department_salary_rank=1, department_avg_age_diff__gt=0),
+ ["Miller"],
+ lambda employee: employee.name,
+ )
+ # Union of multiple windows.
+ self.assertQuerysetEqual(
+ qs.filter(Q(department_salary_rank=1) | Q(department_avg_age_diff__gt=0)),
+ [
+ "Adams",
+ "Jenson",
+ "Jones",
+ "Williams",
+ "Wilkinson",
+ "Miller",
+ "Johnson",
+ "Smith",
+ "Smith",
+ ],
+ lambda employee: employee.name,
+ )
+
+ def test_filter_conditional_annotation(self):
+ qs = (
+ Employee.objects.annotate(
+ rank=Window(Rank(), partition_by="department", order_by="-salary"),
+ case_first_rank=Case(
+ When(rank=1, then=True),
+ default=False,
+ ),
+ q_first_rank=Q(rank=1),
+ )
+ .order_by("name")
+ .values_list("name", flat=True)
+ )
+ for annotation in ["case_first_rank", "q_first_rank"]:
+ with self.subTest(annotation=annotation):
+ self.assertSequenceEqual(
+ qs.filter(**{annotation: True}),
+ ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"],
+ )
+
+ def test_filter_conditional_expression(self):
+ qs = (
+ Employee.objects.filter(
+ Exact(Window(Rank(), partition_by="department", order_by="-salary"), 1)
+ )
+ .order_by("name")
+ .values_list("name", flat=True)
+ )
+ self.assertSequenceEqual(
+ qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"]
+ )
+
+ def test_filter_column_ref_rhs(self):
+ qs = (
+ Employee.objects.annotate(
+ max_dept_salary=Window(Max("salary"), partition_by="department")
+ )
+ .filter(max_dept_salary=F("salary"))
+ .order_by("name")
+ .values_list("name", flat=True)
+ )
+ self.assertSequenceEqual(
+ qs, ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"]
+ )
+
+ def test_filter_values(self):
+ qs = (
+ Employee.objects.annotate(
+ department_salary_rank=Window(
+ Rank(), partition_by="department", order_by="-salary"
+ ),
+ )
+ .order_by("department", "name")
+ .values_list(Upper("name"), flat=True)
+ )
+ self.assertSequenceEqual(
+ qs.filter(department_salary_rank=1),
+ ["ADAMS", "WILKINSON", "MILLER", "JOHNSON", "SMITH"],
+ )
+
+ def test_filter_alias(self):
+ qs = Employee.objects.alias(
+ department_avg_age_diff=(
+ Window(Avg("age"), partition_by="department") - F("age")
+ ),
+ ).order_by("department", "name")
+ self.assertQuerysetEqual(
+ qs.filter(department_avg_age_diff__gt=0),
+ ["Jenson", "Jones", "Williams", "Miller", "Smith"],
+ lambda employee: employee.name,
+ )
+
+ def test_filter_select_related(self):
+ qs = (
+ Employee.objects.alias(
+ department_avg_age_diff=(
+ Window(Avg("age"), partition_by="department") - F("age")
+ ),
+ )
+ .select_related("classification")
+ .filter(department_avg_age_diff__gt=0)
+ .order_by("department", "name")
+ )
+ self.assertQuerysetEqual(
+ qs,
+ ["Jenson", "Jones", "Williams", "Miller", "Smith"],
+ lambda employee: employee.name,
+ )
+ with self.assertNumQueries(0):
+ qs[0].classification
+
+ def test_exclude(self):
+ qs = Employee.objects.annotate(
+ department_salary_rank=Window(
+ Rank(), partition_by="department", order_by="-salary"
+ ),
+ department_avg_age_diff=(
+ Window(Avg("age"), partition_by="department") - F("age")
+ ),
+ ).order_by("department", "name")
+ # Direct window reference.
+ self.assertQuerysetEqual(
+ qs.exclude(department_salary_rank__gt=1),
+ ["Adams", "Wilkinson", "Miller", "Johnson", "Smith"],
+ lambda employee: employee.name,
+ )
+ # Through a combined expression containing a window.
+ self.assertQuerysetEqual(
+ qs.exclude(department_avg_age_diff__lte=0),
+ ["Jenson", "Jones", "Williams", "Miller", "Smith"],
+ lambda employee: employee.name,
+ )
+ # Union of multiple windows.
+ self.assertQuerysetEqual(
+ qs.exclude(
+ Q(department_salary_rank__gt=1) | Q(department_avg_age_diff__lte=0)
+ ),
+ ["Miller"],
+ lambda employee: employee.name,
+ )
+ # Intersection of multiple windows.
+ self.assertQuerysetEqual(
+ qs.exclude(department_salary_rank__gt=1, department_avg_age_diff__lte=0),
+ [
+ "Adams",
+ "Jenson",
+ "Jones",
+ "Williams",
+ "Wilkinson",
+ "Miller",
+ "Johnson",
+ "Smith",
+ "Smith",
+ ],
+ lambda employee: employee.name,
+ )
+
+ def test_heterogeneous_filter(self):
+ qs = (
+ Employee.objects.annotate(
+ department_salary_rank=Window(
+ Rank(), partition_by="department", order_by="-salary"
+ ),
+ )
+ .order_by("name")
+ .values_list("name", flat=True)
+ )
+ # Heterogeneous filter between window function and aggregates pushes
+ # the WHERE clause to the QUALIFY outer query.
+ self.assertSequenceEqual(
+ qs.filter(
+ department_salary_rank=1, department__in=["Accounting", "Management"]
+ ),
+ ["Adams", "Miller"],
+ )
+ self.assertSequenceEqual(
+ qs.filter(
+ Q(department_salary_rank=1)
+ | Q(department__in=["Accounting", "Management"])
+ ),
+ [
+ "Adams",
+ "Jenson",
+ "Johnson",
+ "Johnson",
+ "Jones",
+ "Miller",
+ "Smith",
+ "Wilkinson",
+ "Williams",
+ ],
+ )
+ # Heterogeneous filter between window function and aggregates pushes
+ # the HAVING clause to the QUALIFY outer query.
+ qs = qs.annotate(past_department_count=Count("past_departments"))
+ self.assertSequenceEqual(
+ qs.filter(department_salary_rank=1, past_department_count__gte=1),
+ ["Johnson", "Miller"],
+ )
+ self.assertSequenceEqual(
+ qs.filter(Q(department_salary_rank=1) | Q(past_department_count__gte=1)),
+ ["Adams", "Johnson", "Miller", "Smith", "Wilkinson"],
+ )
+
+ def test_limited_filter(self):
+ """
+ A query filtering against a window function have its limit applied
+ after window filtering takes place.
+ """
+ self.assertQuerysetEqual(
+ Employee.objects.annotate(
+ department_salary_rank=Window(
+ Rank(), partition_by="department", order_by="-salary"
+ )
+ )
+ .filter(department_salary_rank=1)
+ .order_by("department")[0:3],
+ ["Adams", "Wilkinson", "Miller"],
+ lambda employee: employee.name,
+ )
+
+ def test_filter_count(self):
+ self.assertEqual(
+ Employee.objects.annotate(
+ department_salary_rank=Window(
+ Rank(), partition_by="department", order_by="-salary"
+ )
+ )
+ .filter(department_salary_rank=1)
+ .count(),
+ 5,
+ )
+
@skipUnlessDBFeature("supports_frame_range_fixed_distance")
def test_range_n_preceding_and_following(self):
qs = Employee.objects.annotate(
@@ -1071,6 +1337,7 @@ class WindowFunctionTests(TestCase):
),
year=ExtractYear("hire_date"),
)
+ .filter(sum__gte=45000)
.values("year", "sum")
.distinct("year")
.order_by("year")
@@ -1081,7 +1348,6 @@ class WindowFunctionTests(TestCase):
{"year": 2008, "sum": 45000},
{"year": 2009, "sum": 128000},
{"year": 2011, "sum": 60000},
- {"year": 2012, "sum": 40000},
{"year": 2013, "sum": 84000},
]
for idx, val in zip(range(len(results)), results):
@@ -1348,34 +1614,18 @@ class NonQueryWindowTests(SimpleTestCase):
frame.window_frame_start_end(None, None, None)
def test_invalid_filter(self):
- msg = "Window is disallowed in the filter clause"
- qs = Employee.objects.annotate(dense_rank=Window(expression=DenseRank()))
- with self.assertRaisesMessage(NotSupportedError, msg):
- qs.filter(dense_rank__gte=1)
- with self.assertRaisesMessage(NotSupportedError, msg):
- qs.annotate(inc_rank=F("dense_rank") + Value(1)).filter(inc_rank__gte=1)
- with self.assertRaisesMessage(NotSupportedError, msg):
- qs.filter(id=F("dense_rank"))
- with self.assertRaisesMessage(NotSupportedError, msg):
- qs.filter(id=Func("dense_rank", 2, function="div"))
- with self.assertRaisesMessage(NotSupportedError, msg):
- qs.annotate(total=Sum("dense_rank", filter=Q(name="Jones"))).filter(total=1)
-
- def test_conditional_annotation(self):
+ msg = (
+ "Heterogeneous disjunctive predicates against window functions are not "
+ "implemented when performing conditional aggregation."
+ )
qs = Employee.objects.annotate(
- dense_rank=Window(expression=DenseRank()),
- ).annotate(
- equal=Case(
- When(id=F("dense_rank"), then=Value(True)),
- default=Value(False),
- output_field=BooleanField(),
- ),
+ window=Window(Rank()),
+ past_dept_cnt=Count("past_departments"),
)
- # The SQL standard disallows referencing window functions in the WHERE
- # clause.
- msg = "Window is disallowed in the filter clause"
- with self.assertRaisesMessage(NotSupportedError, msg):
- qs.filter(equal=True)
+ with self.assertRaisesMessage(NotImplementedError, msg):
+ list(qs.filter(Q(window=1) | Q(department="Accounting")))
+ with self.assertRaisesMessage(NotImplementedError, msg):
+ list(qs.exclude(window=1, department="Accounting"))
def test_invalid_order_by(self):
msg = (