summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/db/models/fields/related.py98
-rw-r--r--django/db/models/fields/related_lookups.py130
-rw-r--r--django/db/models/query_utils.py16
-rw-r--r--django/db/models/sql/query.py47
-rw-r--r--tests/generic_relations_regress/tests.py16
-rw-r--r--tests/queries/tests.py8
6 files changed, 220 insertions, 95 deletions
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 0ca4e59633..e91de27876 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -15,7 +15,10 @@ from django.db.models.fields import (
BLANK_CHOICE_DASH, AutoField, Field, IntegerField, PositiveIntegerField,
PositiveSmallIntegerField,
)
-from django.db.models.lookups import IsNull
+from django.db.models.fields.related_lookups import (
+ RelatedExact, RelatedGreaterThan, RelatedGreaterThanOrEqual, RelatedIn,
+ RelatedLessThan, RelatedLessThanOrEqual,
+)
from django.db.models.query import QuerySet
from django.db.models.query_utils import PathInfo
from django.utils import six
@@ -1336,6 +1339,16 @@ class ForeignObjectRel(object):
def one_to_one(self):
return self.field.one_to_one
+ def get_prep_lookup(self, lookup_name, value):
+ return self.field.get_prep_lookup(lookup_name, value)
+
+ def get_internal_type(self):
+ return self.field.get_internal_type()
+
+ @property
+ def db_type(self):
+ return self.field.db_type
+
def __repr__(self):
return '<%s: %s.%s>' % (
type(self).__name__,
@@ -1760,67 +1773,25 @@ class ForeignObject(RelatedField):
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
return pathinfos
- def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups,
- raw_value):
- from django.db.models.sql.where import SubqueryConstraint, AND, OR
- root_constraint = constraint_class()
- assert len(targets) == len(sources)
- if len(lookups) > 1:
- raise exceptions.FieldError(
- "Cannot resolve keyword %r into field. Choices are: %s" % (
- lookups[0],
- ", ".join(f.name for f in self.model._meta.get_fields()),
- )
- )
- lookup_type = lookups[0]
-
- def get_normalized_value(value):
- from django.db.models import Model
- if isinstance(value, Model):
- value_list = []
- for source in sources:
- # Account for one-to-one relations when sent a different model
- while not isinstance(value, source.model) and source.rel:
- source = source.rel.to._meta.get_field(source.rel.field_name)
- value_list.append(getattr(value, source.attname))
- return tuple(value_list)
- elif not isinstance(value, tuple):
- return (value,)
- return value
-
- is_multicolumn = len(self.related_fields) > 1
- if (hasattr(raw_value, '_as_sql') or
- hasattr(raw_value, 'get_compiler')):
- root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets],
- [source.name for source in sources], raw_value),
- AND)
- elif lookup_type == 'isnull':
- root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND)
- elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
- and not is_multicolumn)):
- value = get_normalized_value(raw_value)
- for target, source, val in zip(targets, sources, value):
- lookup_class = target.get_lookup(lookup_type)
- root_constraint.add(
- lookup_class(target.get_col(alias, source), val), AND)
- elif lookup_type in ['range', 'in'] and not is_multicolumn:
- values = [get_normalized_value(value) for value in raw_value]
- value = [val[0] for val in values]
- lookup_class = targets[0].get_lookup(lookup_type)
- root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND)
- elif lookup_type == 'in':
- values = [get_normalized_value(value) for value in raw_value]
- root_constraint.connector = OR
- for value in values:
- value_constraint = constraint_class()
- for source, target, val in zip(sources, targets, value):
- lookup_class = target.get_lookup('exact')
- lookup = lookup_class(target.get_col(alias, source), val)
- value_constraint.add(lookup, AND)
- root_constraint.add(value_constraint, OR)
- else:
- raise TypeError('Related Field got invalid lookup: %s' % lookup_type)
- return root_constraint
+ def get_lookup(self, lookup_name):
+ if lookup_name == 'in':
+ return RelatedIn
+ elif lookup_name == 'exact':
+ return RelatedExact
+ elif lookup_name == 'gt':
+ return RelatedGreaterThan
+ elif lookup_name == 'gte':
+ return RelatedGreaterThanOrEqual
+ elif lookup_name == 'lt':
+ return RelatedLessThan
+ elif lookup_name == 'lte':
+ return RelatedLessThanOrEqual
+ elif lookup_name != 'isnull':
+ raise TypeError('Related Field got invalid lookup: %s' % lookup_name)
+ return super(ForeignObject, self).get_lookup(lookup_name)
+
+ def get_transform(self, *args, **kwargs):
+ raise NotImplementedError('Relational fields do not support transforms.')
@property
def attnames(self):
@@ -2017,6 +1988,9 @@ class ForeignKey(ForeignObject):
else:
return self.related_field.get_db_prep_save(value, connection=connection)
+ def get_db_prep_value(self, value, connection, prepared=False):
+ return self.related_field.get_db_prep_value(value, connection, prepared)
+
def value_to_string(self, obj):
if not obj:
# In required many-to-one fields with only one available choice,
diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py
new file mode 100644
index 0000000000..b689c9928e
--- /dev/null
+++ b/django/db/models/fields/related_lookups.py
@@ -0,0 +1,130 @@
+from django.db.models.lookups import (
+ Exact, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual,
+)
+
+
+class MultiColSource(object):
+ contains_aggregate = False
+
+ def __init__(self, alias, targets, sources, field):
+ self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
+ self.output_field = self.field
+
+ def __repr__(self):
+ return "{}({}, {})".format(
+ self.__class__.__name__, self.alias, self.field)
+
+ def relabeled_clone(self, relabels):
+ return self.__class__(relabels.get(self.alias, self.alias),
+ self.targets, self.sources, self.field)
+
+
+def get_normalized_value(value, lhs):
+ from django.db.models import Model
+ if isinstance(value, Model):
+ value_list = []
+ # Account for one-to-one relations when sent a different model
+ sources = lhs.output_field.get_path_info()[-1].target_fields
+ for source in sources:
+ while not isinstance(value, source.model) and source.rel:
+ source = source.rel.to._meta.get_field(source.rel.field_name)
+ value_list.append(getattr(value, source.attname))
+ return tuple(value_list)
+ if not isinstance(value, tuple):
+ return (value,)
+ return value
+
+
+class RelatedIn(In):
+ def get_prep_lookup(self):
+ if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
+ # If we get here, we are dealing with single-column relations.
+ self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
+ # We need to run the related field's get_prep_lookup(). Consider case
+ # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
+ # doesn't have validation for non-integers, so we must run validation
+ # using the target field.
+ if hasattr(self.lhs.output_field, 'get_path_info'):
+ # Run the target field's get_prep_lookup. We can safely assume there is
+ # only one as we don't get to the direct value branch otherwise.
+ self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup(
+ self.lookup_name, self.rhs)
+ return super(RelatedIn, self).get_prep_lookup()
+
+ def as_sql(self, compiler, connection):
+ if isinstance(self.lhs, MultiColSource):
+ # For multicolumn lookups we need to build a multicolumn where clause.
+ # This clause is either a SubqueryConstraint (for values that need to be compiled to
+ # SQL) or a OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
+ from django.db.models.sql.where import WhereNode, SubqueryConstraint, AND, OR
+
+ root_constraint = WhereNode(connector=OR)
+ if self.rhs_is_direct_value():
+ values = [get_normalized_value(value, self.lhs) for value in self.rhs]
+ for value in values:
+ value_constraint = WhereNode()
+ for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
+ lookup_class = target.get_lookup('exact')
+ lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
+ value_constraint.add(lookup, AND)
+ root_constraint.add(value_constraint, OR)
+ else:
+ root_constraint.add(
+ SubqueryConstraint(
+ self.lhs.alias, [target.column for target in self.lhs.targets],
+ [source.name for source in self.lhs.sources], self.rhs),
+ AND)
+ return root_constraint.as_sql(compiler, connection)
+ else:
+ return super(RelatedIn, self).as_sql(compiler, connection)
+
+
+class RelatedLookupMixin(object):
+ def get_prep_lookup(self):
+ if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
+ # If we get here, we are dealing with single-column relations.
+ self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
+ # We need to run the related field's get_prep_lookup(). Consider case
+ # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
+ # doesn't have validation for non-integers, so we must run validation
+ # using the target field.
+ if hasattr(self.lhs.output_field, 'get_path_info'):
+ # Get the target field. We can safely assume there is only one
+ # as we don't get to the direct value branch otherwise.
+ self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup(
+ self.lookup_name, self.rhs)
+
+ return super(RelatedLookupMixin, self).get_prep_lookup()
+
+ def as_sql(self, compiler, connection):
+ if isinstance(self.lhs, MultiColSource):
+ assert self.rhs_is_direct_value()
+ self.rhs = get_normalized_value(self.rhs, self.lhs)
+ from django.db.models.sql.where import WhereNode, AND
+ root_constraint = WhereNode()
+ for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
+ lookup_class = target.get_lookup(self.lookup_name)
+ root_constraint.add(
+ lookup_class(target.get_col(self.lhs.alias, source), val), AND)
+ return root_constraint.as_sql(compiler, connection)
+ return super(RelatedLookupMixin, self).as_sql(compiler, connection)
+
+
+class RelatedExact(RelatedLookupMixin, Exact):
+ pass
+
+
+class RelatedLessThan(RelatedLookupMixin, LessThan):
+ pass
+
+
+class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
+ pass
+
+
+class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
+ pass
+
+
+class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
+ pass
diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
index 005bd99956..35e3fe58bc 100644
--- a/django/db/models/query_utils.py
+++ b/django/db/models/query_utils.py
@@ -250,7 +250,7 @@ deferred_class_factory.__safe_for_unpickling__ = True
def refs_aggregate(lookup_parts, aggregates):
"""
- A little helper method to check if the lookup_parts contains references
+ A helper method to check if the lookup_parts contains references
to the given aggregates set. Because the LOOKUP_SEP is contained in the
default annotation names we must check each prefix of the lookup_parts
for a match.
@@ -260,3 +260,17 @@ def refs_aggregate(lookup_parts, aggregates):
if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate:
return aggregates[level_n_lookup], lookup_parts[n:]
return False, ()
+
+
+def refs_expression(lookup_parts, annotations):
+ """
+ A helper method to check if the lookup_parts contains references
+ to the given annotations set. Because the LOOKUP_SEP is contained in the
+ default annotation names we must check each prefix of the lookup_parts
+ for a match.
+ """
+ for n in range(len(lookup_parts) + 1):
+ level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
+ if level_n_lookup in annotations and annotations[level_n_lookup]:
+ return annotations[level_n_lookup], lookup_parts[n:]
+ return False, ()
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index 14e079ee84..9c19e819fd 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -17,7 +17,8 @@ from django.db import DEFAULT_DB_ALIAS, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
-from django.db.models.query_utils import Q, PathInfo, refs_aggregate
+from django.db.models.fields.related_lookups import MultiColSource
+from django.db.models.query_utils import Q, PathInfo, refs_expression
from django.db.models.sql.constants import (
INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE,
)
@@ -1006,7 +1007,7 @@ class Query(object):
"""
lookup_splitted = lookup.split(LOOKUP_SEP)
if self._annotations:
- aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations)
+ aggregate, aggregate_lookups = refs_expression(lookup_splitted, self.annotations)
if aggregate:
return aggregate_lookups, (), aggregate
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
@@ -1157,24 +1158,26 @@ class Query(object):
if can_reuse is not None:
can_reuse.update(join_list)
used_joins = set(used_joins).union(set(join_list))
-
- # Process the join list to see if we can remove any non-needed joins from
- # the far end (fewer tables in a query is better).
targets, alias, join_list = self.trim_joins(sources, join_list, path)
- if hasattr(field, 'get_lookup_constraint'):
- # For now foreign keys get special treatment. This should be
- # refactored when composite fields lands.
- condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
- lookups, value)
- lookup_type = lookups[-1]
- else:
- assert(len(targets) == 1)
- if hasattr(targets[0], 'as_sql'):
- # handle Expressions as annotations
- col = targets[0]
+ if field.is_relation:
+ # No support for transforms for relational fields
+ assert len(lookups) == 1
+ lookup_class = field.get_lookup(lookups[0])
+ # Undo the changes done in setup_joins() if hasattr(final_field, 'field') branch
+ # This hack is needed as long as the field.rel isn't like a real field.
+ if field.get_path_info()[-1].target_fields != sources:
+ target_field = field.rel
else:
- col = targets[0].get_col(alias, field)
+ target_field = field
+ if len(targets) == 1:
+ lhs = targets[0].get_col(alias, target_field)
+ else:
+ lhs = MultiColSource(alias, targets, sources, target_field)
+ condition = lookup_class(lhs, value)
+ lookup_type = lookup_class.lookup_name
+ else:
+ col = targets[0].get_col(alias, field)
condition = self.build_lookup(lookups, col, value)
lookup_type = condition.lookup_name
@@ -1284,14 +1287,6 @@ class Query(object):
)
model = field.model._meta.concrete_model
except FieldDoesNotExist:
- # is it an annotation?
- if self._annotations and name in self._annotations:
- field, model = self._annotations[name], None
- if not field.contains_aggregate:
- # Local non-relational field.
- final_field = field
- targets = (field,)
- break
# We didn't find the current field, so move position back
# one step.
pos -= 1
@@ -1985,7 +1980,7 @@ def is_reverse_o2o(field):
A little helper to check if the given field is reverse-o2o. The field is
expected to be some sort of relation field or related object.
"""
- return not hasattr(field, 'rel') and field.field.unique
+ return field.is_relation and field.one_to_one and not field.concrete
class JoinPromoter(object):
diff --git a/tests/generic_relations_regress/tests.py b/tests/generic_relations_regress/tests.py
index 0d78223725..b6782fe13f 100644
--- a/tests/generic_relations_regress/tests.py
+++ b/tests/generic_relations_regress/tests.py
@@ -144,22 +144,26 @@ class GenericRelationTests(TestCase):
tag.save()
def test_ticket_20378(self):
+ # Create a couple of extra HasLinkThing so that the autopk value
+ # isn't the same for Link and HasLinkThing.
hs1 = HasLinkThing.objects.create()
hs2 = HasLinkThing.objects.create()
- l1 = Link.objects.create(content_object=hs1)
- l2 = Link.objects.create(content_object=hs2)
+ hs3 = HasLinkThing.objects.create()
+ hs4 = HasLinkThing.objects.create()
+ l1 = Link.objects.create(content_object=hs3)
+ l2 = Link.objects.create(content_object=hs4)
self.assertQuerysetEqual(
HasLinkThing.objects.filter(links=l1),
- [hs1], lambda x: x)
+ [hs3], lambda x: x)
self.assertQuerysetEqual(
HasLinkThing.objects.filter(links=l2),
- [hs2], lambda x: x)
+ [hs4], lambda x: x)
self.assertQuerysetEqual(
HasLinkThing.objects.exclude(links=l2),
- [hs1], lambda x: x)
+ [hs1, hs2, hs3], lambda x: x, ordered=False)
self.assertQuerysetEqual(
HasLinkThing.objects.exclude(links=l1),
- [hs2], lambda x: x)
+ [hs1, hs2, hs4], lambda x: x, ordered=False)
def test_ticket_20564(self):
b1 = B.objects.create()
diff --git a/tests/queries/tests.py b/tests/queries/tests.py
index 9c521c1c03..82d6d1fe7c 100644
--- a/tests/queries/tests.py
+++ b/tests/queries/tests.py
@@ -3678,3 +3678,11 @@ class TestTicket24279(TestCase):
School.objects.create()
qs = School.objects.filter(Q(pk__in=()) | Q())
self.assertQuerysetEqual(qs, [])
+
+
+class TestInvalidValuesRelation(TestCase):
+ def test_invalid_values(self):
+ with self.assertRaises(ValueError):
+ Annotation.objects.filter(tag='abc')
+ with self.assertRaises(ValueError):
+ Annotation.objects.filter(tag__in=[123, 'abc'])