From 9bbf97bcdb488bb11aebb5bd405549fbec6852cd Mon Sep 17 00:00:00 2001 From: David Wobrock Date: Tue, 18 Apr 2023 10:19:06 +0200 Subject: Fixed #16055 -- Fixed crash when filtering against char/text GenericRelation relation on PostgreSQL. --- django/db/backends/base/operations.py | 7 ++++++ django/db/backends/oracle/features.py | 3 +++ django/db/backends/postgresql/operations.py | 11 ++++++++++ django/db/models/fields/related.py | 8 +++++++ django/db/models/fields/reverse_related.py | 3 +++ django/db/models/sql/datastructures.py | 33 +++++++++++++++++++---------- 6 files changed, 54 insertions(+), 11 deletions(-) (limited to 'django') diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index d2bc336dd8..6f10e31cd5 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -8,6 +8,7 @@ import sqlparse from django.conf import settings from django.db import NotSupportedError, transaction from django.db.backends import utils +from django.db.models.expressions import Col from django.utils import timezone from django.utils.encoding import force_str @@ -776,3 +777,9 @@ class BaseDatabaseOperations: def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): return "" + + def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): + lhs_expr = Col(lhs_table, lhs_field) + rhs_expr = Col(rhs_table, rhs_field) + + return lhs_expr, rhs_expr diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 3d77a615c8..05dc552a98 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -120,6 +120,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): "migrations.test_operations.OperationTests." "test_alter_field_pk_fk_db_collation", }, + "Oracle doesn't support comparing NCLOB to NUMBER.": { + "generic_relations_regress.tests.GenericRelationTests.test_textlink_filter", + }, } django_test_expected_failures = { # A bug in Django/cx_Oracle with respect to string handling (#23843). diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 18cfcb29cb..aa839f5634 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import ( ) from django.db.backends.utils import split_tzname_delta from django.db.models.constants import OnConflict +from django.db.models.functions import Cast from django.utils.regex_helper import _lazy_re_compile @@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations): update_fields, unique_fields, ) + + def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): + lhs_expr, rhs_expr = super().prepare_join_on_clause( + lhs_table, lhs_field, rhs_table, rhs_field + ) + + if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection): + rhs_expr = Cast(rhs_expr, lhs_field) + + return lhs_expr, rhs_expr diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 0efbe53a0b..7a49861164 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -785,6 +785,14 @@ class ForeignObject(RelatedField): def get_reverse_joining_columns(self): return self.get_joining_columns(reverse_join=True) + def get_joining_fields(self, reverse_join=False): + return tuple( + self.reverse_related_fields if reverse_join else self.related_fields + ) + + def get_reverse_joining_fields(self): + return self.get_joining_fields(reverse_join=True) + def get_extra_descriptor_filter(self, instance): """ Return an extra filter condition for related object fetching when diff --git a/django/db/models/fields/reverse_related.py b/django/db/models/fields/reverse_related.py index b7d82f6258..f3da8f8bf2 100644 --- a/django/db/models/fields/reverse_related.py +++ b/django/db/models/fields/reverse_related.py @@ -195,6 +195,9 @@ class ForeignObjectRel(FieldCacheMixin): def get_joining_columns(self): return self.field.get_reverse_joining_columns() + def get_joining_fields(self): + return self.field.get_reverse_joining_fields() + def get_extra_restriction(self, alias, related_alias): return self.field.get_extra_restriction(related_alias, alias) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 069eb1a301..46a977188a 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -61,7 +61,15 @@ class Join: self.join_type = join_type # A list of 2-tuples to use in the ON clause of the JOIN. # Each 2-tuple will create one join condition in the ON clause. - self.join_cols = join_field.get_joining_columns() + if hasattr(join_field, "get_joining_fields"): + self.join_fields = join_field.get_joining_fields() + self.join_cols = tuple( + (lhs_field.column, rhs_field.column) + for lhs_field, rhs_field in self.join_fields + ) + else: + self.join_fields = None + self.join_cols = join_field.get_joining_columns() # Along which field (or ForeignObjectRel in the reverse join case) self.join_field = join_field # Is this join nullabled? @@ -78,18 +86,21 @@ class Join: params = [] qn = compiler.quote_name_unless_alias qn2 = connection.ops.quote_name - # Add a join condition for each pair of joining columns. - for lhs_col, rhs_col in self.join_cols: - join_conditions.append( - "%s.%s = %s.%s" - % ( - qn(self.parent_alias), - qn2(lhs_col), - qn(self.table_alias), - qn2(rhs_col), + join_fields = self.join_fields or self.join_cols + for lhs, rhs in join_fields: + if isinstance(lhs, str): + lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs)) + rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs)) + else: + lhs, rhs = connection.ops.prepare_join_on_clause( + self.parent_alias, lhs, self.table_alias, rhs ) - ) + lhs_sql, lhs_params = compiler.compile(lhs) + lhs_full_name = lhs_sql % lhs_params + rhs_sql, rhs_params = compiler.compile(rhs) + rhs_full_name = rhs_sql % rhs_params + join_conditions.append(f"{lhs_full_name} = {rhs_full_name}") # Add a single condition inside parentheses for whatever # get_extra_restriction() returns. -- cgit v1.2.1