summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Wobrock <david.wobrock@gmail.com>2023-04-18 10:19:06 +0200
committerMariusz Felisiak <felisiak.mariusz@gmail.com>2023-04-18 12:41:14 +0200
commit9bbf97bcdb488bb11aebb5bd405549fbec6852cd (patch)
treed679e1a72227b9b6b5b5f44af6a4e231228f2d0d
parent594fcc2b7427f7baf2cf1a2d7cd2be61467df0c3 (diff)
downloaddjango-9bbf97bcdb488bb11aebb5bd405549fbec6852cd.tar.gz
Fixed #16055 -- Fixed crash when filtering against char/text GenericRelation relation on PostgreSQL.
-rw-r--r--django/db/backends/base/operations.py7
-rw-r--r--django/db/backends/oracle/features.py3
-rw-r--r--django/db/backends/postgresql/operations.py11
-rw-r--r--django/db/models/fields/related.py8
-rw-r--r--django/db/models/fields/reverse_related.py3
-rw-r--r--django/db/models/sql/datastructures.py33
-rw-r--r--tests/backends/base/test_operations.py15
-rw-r--r--tests/backends/postgresql/test_operations.py32
-rw-r--r--tests/foreign_object/models/empty_join.py2
-rw-r--r--tests/generic_relations_regress/models.py2
-rw-r--r--tests/generic_relations_regress/tests.py14
11 files changed, 117 insertions, 13 deletions
diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
index d2bc336dd8..6f10e31cd5 100644
--- a/django/db/backends/base/operations.py
+++ b/django/db/backends/base/operations.py
@@ -8,6 +8,7 @@ import sqlparse
from django.conf import settings
from django.db import NotSupportedError, transaction
from django.db.backends import utils
+from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.encoding import force_str
@@ -776,3 +777,9 @@ class BaseDatabaseOperations:
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
return ""
+
+ def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
+ lhs_expr = Col(lhs_table, lhs_field)
+ rhs_expr = Col(rhs_table, rhs_field)
+
+ return lhs_expr, rhs_expr
diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py
index 3d77a615c8..05dc552a98 100644
--- a/django/db/backends/oracle/features.py
+++ b/django/db/backends/oracle/features.py
@@ -120,6 +120,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"migrations.test_operations.OperationTests."
"test_alter_field_pk_fk_db_collation",
},
+ "Oracle doesn't support comparing NCLOB to NUMBER.": {
+ "generic_relations_regress.tests.GenericRelationTests.test_textlink_filter",
+ },
}
django_test_expected_failures = {
# A bug in Django/cx_Oracle with respect to string handling (#23843).
diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
index 18cfcb29cb..aa839f5634 100644
--- a/django/db/backends/postgresql/operations.py
+++ b/django/db/backends/postgresql/operations.py
@@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import (
)
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
+from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile
@@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations):
update_fields,
unique_fields,
)
+
+ def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
+ lhs_expr, rhs_expr = super().prepare_join_on_clause(
+ lhs_table, lhs_field, rhs_table, rhs_field
+ )
+
+ if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
+ rhs_expr = Cast(rhs_expr, lhs_field)
+
+ return lhs_expr, rhs_expr
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 0efbe53a0b..7a49861164 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -785,6 +785,14 @@ class ForeignObject(RelatedField):
def get_reverse_joining_columns(self):
return self.get_joining_columns(reverse_join=True)
+ def get_joining_fields(self, reverse_join=False):
+ return tuple(
+ self.reverse_related_fields if reverse_join else self.related_fields
+ )
+
+ def get_reverse_joining_fields(self):
+ return self.get_joining_fields(reverse_join=True)
+
def get_extra_descriptor_filter(self, instance):
"""
Return an extra filter condition for related object fetching when
diff --git a/django/db/models/fields/reverse_related.py b/django/db/models/fields/reverse_related.py
index b7d82f6258..f3da8f8bf2 100644
--- a/django/db/models/fields/reverse_related.py
+++ b/django/db/models/fields/reverse_related.py
@@ -195,6 +195,9 @@ class ForeignObjectRel(FieldCacheMixin):
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
+ def get_joining_fields(self):
+ return self.field.get_reverse_joining_fields()
+
def get_extra_restriction(self, alias, related_alias):
return self.field.get_extra_restriction(related_alias, alias)
diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
index 069eb1a301..46a977188a 100644
--- a/django/db/models/sql/datastructures.py
+++ b/django/db/models/sql/datastructures.py
@@ -61,7 +61,15 @@ class Join:
self.join_type = join_type
# A list of 2-tuples to use in the ON clause of the JOIN.
# Each 2-tuple will create one join condition in the ON clause.
- self.join_cols = join_field.get_joining_columns()
+ if hasattr(join_field, "get_joining_fields"):
+ self.join_fields = join_field.get_joining_fields()
+ self.join_cols = tuple(
+ (lhs_field.column, rhs_field.column)
+ for lhs_field, rhs_field in self.join_fields
+ )
+ else:
+ self.join_fields = None
+ self.join_cols = join_field.get_joining_columns()
# Along which field (or ForeignObjectRel in the reverse join case)
self.join_field = join_field
# Is this join nullabled?
@@ -78,18 +86,21 @@ class Join:
params = []
qn = compiler.quote_name_unless_alias
qn2 = connection.ops.quote_name
-
# Add a join condition for each pair of joining columns.
- for lhs_col, rhs_col in self.join_cols:
- join_conditions.append(
- "%s.%s = %s.%s"
- % (
- qn(self.parent_alias),
- qn2(lhs_col),
- qn(self.table_alias),
- qn2(rhs_col),
+ join_fields = self.join_fields or self.join_cols
+ for lhs, rhs in join_fields:
+ if isinstance(lhs, str):
+ lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
+ rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
+ else:
+ lhs, rhs = connection.ops.prepare_join_on_clause(
+ self.parent_alias, lhs, self.table_alias, rhs
)
- )
+ lhs_sql, lhs_params = compiler.compile(lhs)
+ lhs_full_name = lhs_sql % lhs_params
+ rhs_sql, rhs_params = compiler.compile(rhs)
+ rhs_full_name = rhs_sql % rhs_params
+ join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
# Add a single condition inside parentheses for whatever
# get_extra_restriction() returns.
diff --git a/tests/backends/base/test_operations.py b/tests/backends/base/test_operations.py
index 5260344da7..9d2828c8ce 100644
--- a/tests/backends/base/test_operations.py
+++ b/tests/backends/base/test_operations.py
@@ -4,6 +4,7 @@ from django.core.management.color import no_style
from django.db import NotSupportedError, connection, transaction
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import DurationField, Value
+from django.db.models.expressions import Col
from django.test import (
SimpleTestCase,
TestCase,
@@ -159,6 +160,20 @@ class SimpleDatabaseOperationTests(SimpleTestCase):
):
self.ops.datetime_extract_sql(None, None, None, None)
+ def test_prepare_join_on_clause(self):
+ author_table = Author._meta.db_table
+ author_id_field = Author._meta.get_field("id")
+ book_table = Book._meta.db_table
+ book_fk_field = Book._meta.get_field("author")
+ lhs_expr, rhs_expr = self.ops.prepare_join_on_clause(
+ author_table,
+ author_id_field,
+ book_table,
+ book_fk_field,
+ )
+ self.assertEqual(lhs_expr, Col(author_table, author_id_field))
+ self.assertEqual(rhs_expr, Col(book_table, book_fk_field))
+
class DatabaseOperationTests(TestCase):
def setUp(self):
diff --git a/tests/backends/postgresql/test_operations.py b/tests/backends/postgresql/test_operations.py
index c2f2417923..632928ff87 100644
--- a/tests/backends/postgresql/test_operations.py
+++ b/tests/backends/postgresql/test_operations.py
@@ -2,9 +2,11 @@ import unittest
from django.core.management.color import no_style
from django.db import connection
+from django.db.models.expressions import Col
+from django.db.models.functions import Cast
from django.test import SimpleTestCase
-from ..models import Person, Tag
+from ..models import Author, Book, Person, Tag
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests.")
@@ -48,3 +50,31 @@ class PostgreSQLOperationsTests(SimpleTestCase):
),
['TRUNCATE "backends_person", "backends_tag" RESTART IDENTITY CASCADE;'],
)
+
+ def test_prepare_join_on_clause_same_type(self):
+ author_table = Author._meta.db_table
+ author_id_field = Author._meta.get_field("id")
+ lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause(
+ author_table,
+ author_id_field,
+ author_table,
+ author_id_field,
+ )
+ self.assertEqual(lhs_expr, Col(author_table, author_id_field))
+ self.assertEqual(rhs_expr, Col(author_table, author_id_field))
+
+ def test_prepare_join_on_clause_different_types(self):
+ author_table = Author._meta.db_table
+ author_id_field = Author._meta.get_field("id")
+ book_table = Book._meta.db_table
+ book_fk_field = Book._meta.get_field("author")
+ lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause(
+ author_table,
+ author_id_field,
+ book_table,
+ book_fk_field,
+ )
+ self.assertEqual(lhs_expr, Col(author_table, author_id_field))
+ self.assertEqual(
+ rhs_expr, Cast(Col(book_table, book_fk_field), author_id_field)
+ )
diff --git a/tests/foreign_object/models/empty_join.py b/tests/foreign_object/models/empty_join.py
index 9c0ada378c..4c3839dcc1 100644
--- a/tests/foreign_object/models/empty_join.py
+++ b/tests/foreign_object/models/empty_join.py
@@ -50,7 +50,7 @@ class StartsWithRelation(models.ForeignObject):
from_field = self.model._meta.get_field(self.from_fields[0])
return StartsWith(to_field.get_col(alias), from_field.get_col(related_alias))
- def get_joining_columns(self, reverse_join=False):
+ def get_joining_fields(self, reverse_join=False):
return ()
def get_path_info(self, filtered_relation=None):
diff --git a/tests/generic_relations_regress/models.py b/tests/generic_relations_regress/models.py
index dc55b2a83b..6867747a26 100644
--- a/tests/generic_relations_regress/models.py
+++ b/tests/generic_relations_regress/models.py
@@ -64,12 +64,14 @@ class CharLink(models.Model):
content_type = models.ForeignKey(ContentType, models.CASCADE)
object_id = models.CharField(max_length=100)
content_object = GenericForeignKey()
+ value = models.CharField(max_length=250)
class TextLink(models.Model):
content_type = models.ForeignKey(ContentType, models.CASCADE)
object_id = models.TextField()
content_object = GenericForeignKey()
+ value = models.CharField(max_length=250)
class OddRelation1(models.Model):
diff --git a/tests/generic_relations_regress/tests.py b/tests/generic_relations_regress/tests.py
index 9b2f21b88b..b7ecb499eb 100644
--- a/tests/generic_relations_regress/tests.py
+++ b/tests/generic_relations_regress/tests.py
@@ -72,6 +72,20 @@ class GenericRelationTests(TestCase):
TextLink.objects.create(content_object=oddrel)
oddrel.delete()
+ def test_charlink_filter(self):
+ oddrel = OddRelation1.objects.create(name="clink")
+ CharLink.objects.create(content_object=oddrel, value="value")
+ self.assertSequenceEqual(
+ OddRelation1.objects.filter(clinks__value="value"), [oddrel]
+ )
+
+ def test_textlink_filter(self):
+ oddrel = OddRelation2.objects.create(name="clink")
+ TextLink.objects.create(content_object=oddrel, value="value")
+ self.assertSequenceEqual(
+ OddRelation2.objects.filter(tlinks__value="value"), [oddrel]
+ )
+
def test_coerce_object_id_remote_field_cache_persistence(self):
restaurant = Restaurant.objects.create()
CharLink.objects.create(content_object=restaurant)