diff options
author | Nicolas Delaby <ticosax@free.fr> | 2017-09-22 17:53:17 +0200 |
---|---|---|
committer | Tim Graham <timograham@gmail.com> | 2017-09-22 11:53:17 -0400 |
commit | 01d440fa1e6b5c62acfa8b3fde43dfa1505f93c6 (patch) | |
tree | 21b1f96ecd0fca636746595bce50eb46abdde880 /django/db | |
parent | 3f9d85d95cab228fd881ea952c707022e9e3bdf3 (diff) | |
download | django-01d440fa1e6b5c62acfa8b3fde43dfa1505f93c6.tar.gz |
Fixed #27332 -- Added FilteredRelation API for conditional join (ON clause) support.
Thanks Anssi Kääriäinen for contributing to the patch.
Diffstat (limited to 'django/db')
-rw-r--r-- | django/db/models/__init__.py | 2 | ||||
-rw-r--r-- | django/db/models/fields/related.py | 54 | ||||
-rw-r--r-- | django/db/models/fields/reverse_related.py | 4 | ||||
-rw-r--r-- | django/db/models/options.py | 10 | ||||
-rw-r--r-- | django/db/models/query.py | 41 | ||||
-rw-r--r-- | django/db/models/query_utils.py | 43 | ||||
-rw-r--r-- | django/db/models/sql/compiler.py | 51 | ||||
-rw-r--r-- | django/db/models/sql/datastructures.py | 42 | ||||
-rw-r--r-- | django/db/models/sql/query.py | 137 |
9 files changed, 311 insertions, 73 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index d29addd1f7..628f92db3c 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -20,6 +20,7 @@ from django.db.models.manager import Manager from django.db.models.query import ( Prefetch, Q, QuerySet, prefetch_related_objects, ) +from django.db.models.query_utils import FilteredRelation # Imports that would create circular imports if sorted from django.db.models.base import DEFERRED, Model # isort:skip @@ -69,6 +70,7 @@ __all__ += [ 'Window', 'WindowFrame', 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', + 'FilteredRelation', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'permalink', ] diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 5cf540d385..34123fd4de 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -697,18 +697,33 @@ class ForeignObject(RelatedField): """ return None - def get_path_info(self): + def get_path_info(self, filtered_relation=None): """Get path from this field to the related model.""" opts = self.remote_field.model._meta from_opts = self.model._meta - return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)] - - def get_reverse_path_info(self): + return [PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=self.foreign_related_fields, + join_field=self, + m2m=False, + direct=True, + filtered_relation=filtered_relation, + )] + + def get_reverse_path_info(self, filtered_relation=None): """Get path from the related model to this field's model.""" opts = self.model._meta from_opts = self.remote_field.model._meta - pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] - return pathinfos + return [PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=(opts.pk,), + join_field=self.remote_field, + m2m=not self.unique, + direct=False, + filtered_relation=filtered_relation, + )] @classmethod @functools.lru_cache(maxsize=None) @@ -861,12 +876,19 @@ class ForeignKey(ForeignObject): def target_field(self): return self.foreign_related_fields[0] - def get_reverse_path_info(self): + def get_reverse_path_info(self, filtered_relation=None): """Get path from the related model to this field's model.""" opts = self.model._meta from_opts = self.remote_field.model._meta - pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.remote_field, not self.unique, False)] - return pathinfos + return [PathInfo( + from_opts=from_opts, + to_opts=opts, + target_fields=(opts.pk,), + join_field=self.remote_field, + m2m=not self.unique, + direct=False, + filtered_relation=filtered_relation, + )] def validate(self, value, model_instance): if self.remote_field.parent_link: @@ -1435,7 +1457,7 @@ class ManyToManyField(RelatedField): ) return name, path, args, kwargs - def _get_path_info(self, direct=False): + def _get_path_info(self, direct=False, filtered_relation=None): """Called by both direct and indirect m2m traversal.""" pathinfos = [] int_model = self.remote_field.through @@ -1443,10 +1465,10 @@ class ManyToManyField(RelatedField): linkfield2 = int_model._meta.get_field(self.m2m_reverse_field_name()) if direct: join1infos = linkfield1.get_reverse_path_info() - join2infos = linkfield2.get_path_info() + join2infos = linkfield2.get_path_info(filtered_relation) else: join1infos = linkfield2.get_reverse_path_info() - join2infos = linkfield1.get_path_info() + join2infos = linkfield1.get_path_info(filtered_relation) # Get join infos between the last model of join 1 and the first model # of join 2. Assume the only reason these may differ is due to model @@ -1465,11 +1487,11 @@ class ManyToManyField(RelatedField): pathinfos.extend(join2infos) return pathinfos - def get_path_info(self): - return self._get_path_info(direct=True) + def get_path_info(self, filtered_relation=None): + return self._get_path_info(direct=True, filtered_relation=filtered_relation) - def get_reverse_path_info(self): - return self._get_path_info(direct=False) + def get_reverse_path_info(self, filtered_relation=None): + return self._get_path_info(direct=False, filtered_relation=filtered_relation) def _get_m2m_db_table(self, opts): """ diff --git a/django/db/models/fields/reverse_related.py b/django/db/models/fields/reverse_related.py index 1f42375566..dddb869513 100644 --- a/django/db/models/fields/reverse_related.py +++ b/django/db/models/fields/reverse_related.py @@ -163,8 +163,8 @@ class ForeignObjectRel(FieldCacheMixin): return self.related_name return opts.model_name + ('_set' if self.multiple else '') - def get_path_info(self): - return self.field.get_reverse_path_info() + def get_path_info(self, filtered_relation=None): + return self.field.get_reverse_path_info(filtered_relation) def get_cache_name(self): """ diff --git a/django/db/models/options.py b/django/db/models/options.py index 9f0746bd58..0786e525b3 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -632,7 +632,15 @@ class Options: final_field = opts.parents[int_model] targets = (final_field.remote_field.get_related_field(),) opts = int_model._meta - path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True)) + path.append(PathInfo( + from_opts=final_field.model._meta, + to_opts=opts, + target_fields=targets, + join_field=final_field, + m2m=False, + direct=True, + filtered_relation=None, + )) return path def get_path_from_parent(self, parent): diff --git a/django/db/models/query.py b/django/db/models/query.py index 42fb728190..3bfe0a6fb4 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -22,7 +22,7 @@ from django.db.models.deletion import Collector from django.db.models.expressions import F from django.db.models.fields import AutoField from django.db.models.functions import Trunc -from django.db.models.query_utils import InvalidQuery, Q +from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.utils import timezone from django.utils.deprecation import RemovedInDjango30Warning @@ -953,6 +953,12 @@ class QuerySet: if lookups == (None,): clone._prefetch_related_lookups = () else: + for lookup in lookups: + if isinstance(lookup, Prefetch): + lookup = lookup.prefetch_to + lookup = lookup.split(LOOKUP_SEP, 1)[0] + if lookup in self.query._filtered_relations: + raise ValueError('prefetch_related() is not supported with FilteredRelation.') clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups return clone @@ -984,7 +990,10 @@ class QuerySet: if alias in names: raise ValueError("The annotation '%s' conflicts with a field on " "the model." % alias) - clone.query.add_annotation(annotation, alias, is_summary=False) + if isinstance(annotation, FilteredRelation): + clone.query.add_filtered_relation(annotation, alias) + else: + clone.query.add_annotation(annotation, alias, is_summary=False) for alias, annotation in clone.query.annotations.items(): if alias in annotations and annotation.contains_aggregate: @@ -1060,6 +1069,10 @@ class QuerySet: # Can only pass None to defer(), not only(), as the rest option. # That won't stop people trying to do this, so let's be explicit. raise TypeError("Cannot pass None as an argument to only().") + for field in fields: + field = field.split(LOOKUP_SEP, 1)[0] + if field in self.query._filtered_relations: + raise ValueError('only() is not supported with FilteredRelation.') clone = self._chain() clone.query.add_immediate_loading(fields) return clone @@ -1730,9 +1743,9 @@ class RelatedPopulator: # model's fields. # - related_populators: a list of RelatedPopulator instances if # select_related() descends to related models from this model. - # - field, remote_field: the fields to use for populating the - # internal fields cache. If remote_field is set then we also - # set the reverse link. + # - local_setter, remote_setter: Methods to set cached values on + # the object being populated and on the remote object. Usually + # these are Field.set_cached_value() methods. select_fields = klass_info['select_fields'] from_parent = klass_info['from_parent'] if not from_parent: @@ -1751,16 +1764,8 @@ class RelatedPopulator: self.model_cls = klass_info['model'] self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) self.related_populators = get_related_populators(klass_info, select, self.db) - reverse = klass_info['reverse'] - field = klass_info['field'] - self.remote_field = None - if reverse: - self.field = field.remote_field - self.remote_field = field - else: - self.field = field - if field.unique: - self.remote_field = field.remote_field + self.local_setter = klass_info['local_setter'] + self.remote_setter = klass_info['remote_setter'] def populate(self, row, from_obj): if self.reorder_for_init: @@ -1774,9 +1779,9 @@ class RelatedPopulator: if self.related_populators: for rel_iter in self.related_populators: rel_iter.populate(row, obj) - if self.remote_field: - self.remote_field.set_cached_value(obj, from_obj) - self.field.set_cached_value(from_obj, obj) + self.local_setter(from_obj, obj) + if obj is not None: + self.remote_setter(obj, from_obj) def get_related_populators(klass_info, select, db): diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index e3f6a730d5..8a889264e5 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -16,7 +16,7 @@ from django.utils import tree # PathInfo is used when converting lookups (fk__somecol). The contents # describe the relation in Model terms (model Options and Fields for both # sides of the relation. The join_field is the field backing the relation. -PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct') +PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation') class InvalidQuery(Exception): @@ -291,3 +291,44 @@ def check_rel_lookup_compatibility(model, target_opts, field): check(target_opts) or (getattr(field, 'primary_key', False) and check(field.model._meta)) ) + + +class FilteredRelation: + """Specify custom filtering in the ON clause of SQL joins.""" + + def __init__(self, relation_name, *, condition=Q()): + if not relation_name: + raise ValueError('relation_name cannot be empty.') + self.relation_name = relation_name + self.alias = None + if not isinstance(condition, Q): + raise ValueError('condition argument must be a Q() instance.') + self.condition = condition + self.path = [] + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) and + self.relation_name == other.relation_name and + self.alias == other.alias and + self.condition == other.condition + ) + + def clone(self): + clone = FilteredRelation(self.relation_name, condition=self.condition) + clone.alias = self.alias + clone.path = self.path[:] + return clone + + def resolve_expression(self, *args, **kwargs): + """ + QuerySet.annotate() only accepts expression-like arguments + (with a resolve_expression() method). + """ + raise NotImplementedError('FilteredRelation.resolve_expression() is unused.') + + def as_sql(self, compiler, connection): + # Resolve the condition in Join.filtered_relation. + query = compiler.query + where = query.build_filtered_relation_q(self.condition, reuse=set(self.path)) + return compiler.compile(where) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 11ff51f60f..14d44d3eef 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -702,7 +702,7 @@ class SQLCompiler: """ result = [] params = [] - for alias in self.query.alias_map: + for alias in tuple(self.query.alias_map): if not self.query.alias_refcount[alias]: continue try: @@ -737,7 +737,7 @@ class SQLCompiler: f.field.related_query_name() for f in opts.related_objects if f.field.unique ) - return chain(direct_choices, reverse_choices) + return chain(direct_choices, reverse_choices, self.query._filtered_relations) related_klass_infos = [] if not restricted and cur_depth > self.query.max_depth: @@ -788,7 +788,8 @@ class SQLCompiler: klass_info = { 'model': f.remote_field.model, 'field': f, - 'reverse': False, + 'local_setter': f.set_cached_value, + 'remote_setter': f.remote_field.set_cached_value if f.unique else lambda x, y: None, 'from_parent': False, } related_klass_infos.append(klass_info) @@ -825,7 +826,8 @@ class SQLCompiler: klass_info = { 'model': model, 'field': f, - 'reverse': True, + 'local_setter': f.remote_field.set_cached_value, + 'remote_setter': f.set_cached_value, 'from_parent': from_parent, } related_klass_infos.append(klass_info) @@ -842,6 +844,47 @@ class SQLCompiler: next, restricted) get_related_klass_infos(klass_info, next_klass_infos) fields_not_found = set(requested).difference(fields_found) + for name in list(requested): + # Filtered relations work only on the topmost level. + if cur_depth > 1: + break + if name in self.query._filtered_relations: + fields_found.add(name) + f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias) + model = join_opts.model + alias = joins[-1] + from_parent = issubclass(model, opts.model) and model is not opts.model + + def local_setter(obj, from_obj): + f.remote_field.set_cached_value(from_obj, obj) + + def remote_setter(obj, from_obj): + setattr(from_obj, name, obj) + klass_info = { + 'model': model, + 'field': f, + 'local_setter': local_setter, + 'remote_setter': remote_setter, + 'from_parent': from_parent, + } + related_klass_infos.append(klass_info) + select_fields = [] + columns = self.get_default_columns( + start_alias=alias, opts=model._meta, + from_parent=opts.model, + ) + for col in columns: + select_fields.append(len(select)) + select.append((col, None)) + klass_info['select_fields'] = select_fields + next_requested = requested.get(name, {}) + next_klass_infos = self.get_related_selections( + select, opts=model._meta, root_alias=alias, + cur_depth=cur_depth + 1, requested=next_requested, + restricted=restricted, + ) + get_related_klass_infos(klass_info, next_klass_infos) + fields_not_found = set(requested).difference(fields_found) if fields_not_found: invalid_fields = ("'%s'" % s for s in fields_not_found) raise FieldError( diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 788c2dd669..ab02f65042 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -41,7 +41,7 @@ class Join: - relabeled_clone() """ def __init__(self, table_name, parent_alias, table_alias, join_type, - join_field, nullable): + join_field, nullable, filtered_relation=None): # Join table self.table_name = table_name self.parent_alias = parent_alias @@ -56,6 +56,7 @@ class Join: self.join_field = join_field # Is this join nullabled? self.nullable = nullable + self.filtered_relation = filtered_relation def as_sql(self, compiler, connection): """ @@ -85,7 +86,11 @@ class Join: extra_sql, extra_params = compiler.compile(extra_cond) join_conditions.append('(%s)' % extra_sql) params.extend(extra_params) - + if self.filtered_relation: + extra_sql, extra_params = compiler.compile(self.filtered_relation) + if extra_sql: + join_conditions.append('(%s)' % extra_sql) + params.extend(extra_params) if not join_conditions: # This might be a rel on the other end of an actual declared field. declared_field = getattr(self.join_field, 'field', self.join_field) @@ -101,18 +106,27 @@ class Join: def relabeled_clone(self, change_map): new_parent_alias = change_map.get(self.parent_alias, self.parent_alias) new_table_alias = change_map.get(self.table_alias, self.table_alias) + if self.filtered_relation is not None: + filtered_relation = self.filtered_relation.clone() + filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path] + else: + filtered_relation = None return self.__class__( self.table_name, new_parent_alias, new_table_alias, self.join_type, - self.join_field, self.nullable) + self.join_field, self.nullable, filtered_relation=filtered_relation, + ) + + def equals(self, other, with_filtered_relation): + return ( + isinstance(other, self.__class__) and + self.table_name == other.table_name and + self.parent_alias == other.parent_alias and + self.join_field == other.join_field and + (not with_filtered_relation or self.filtered_relation == other.filtered_relation) + ) def __eq__(self, other): - if isinstance(other, self.__class__): - return ( - self.table_name == other.table_name and - self.parent_alias == other.parent_alias and - self.join_field == other.join_field - ) - return False + return self.equals(other, with_filtered_relation=True) def demote(self): new = self.relabeled_clone({}) @@ -134,6 +148,7 @@ class BaseTable: """ join_type = None parent_alias = None + filtered_relation = None def __init__(self, table_name, alias): self.table_name = table_name @@ -146,3 +161,10 @@ class BaseTable: def relabeled_clone(self, change_map): return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) + + def equals(self, other, with_filtered_relation): + return ( + isinstance(self, other.__class__) and + self.table_name == other.table_name and + self.table_alias == other.table_alias + ) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index dfa369513b..a962aabdf1 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -45,6 +45,14 @@ def get_field_names_from_opts(opts): )) +def get_children_from_q(q): + for child in q.children: + if isinstance(child, Node): + yield from get_children_from_q(child) + else: + yield child + + JoinInfo = namedtuple( 'JoinInfo', ('final_field', 'targets', 'opts', 'joins', 'path') @@ -210,6 +218,8 @@ class Query: # load. self.deferred_loading = (frozenset(), True) + self._filtered_relations = {} + @property def extra(self): if self._extra is None: @@ -311,6 +321,7 @@ class Query: if 'subq_aliases' in self.__dict__: obj.subq_aliases = self.subq_aliases.copy() obj.used_aliases = self.used_aliases.copy() + obj._filtered_relations = self._filtered_relations.copy() # Clear the cached_property try: del obj.base_table @@ -624,6 +635,8 @@ class Query: opts = orig_opts for name in parts[:-1]: old_model = cur_model + if name in self._filtered_relations: + name = self._filtered_relations[name].relation_name source = opts.get_field(name) if is_reverse_o2o(source): cur_model = source.related_model @@ -684,7 +697,7 @@ class Query: for model, values in seen.items(): callback(target, model, values) - def table_alias(self, table_name, create=False): + def table_alias(self, table_name, create=False, filtered_relation=None): """ Return a table alias for the given table_name and whether this is a new alias or not. @@ -704,8 +717,8 @@ class Query: alias_list.append(alias) else: # The first occurrence of a table uses the table name directly. - alias = table_name - self.table_map[alias] = [alias] + alias = filtered_relation.alias if filtered_relation is not None else table_name + self.table_map[table_name] = [alias] self.alias_refcount[alias] = 1 return alias, True @@ -881,7 +894,7 @@ class Query: """ return len([1 for count in self.alias_refcount.values() if count]) - def join(self, join, reuse=None): + def join(self, join, reuse=None, reuse_with_filtered_relation=False): """ Return an alias for the 'join', either reusing an existing alias for that join or creating a new one. 'join' is either a @@ -890,18 +903,29 @@ class Query: The 'reuse' parameter can be either None which means all joins are reusable, or it can be a set containing the aliases that can be reused. + The 'reuse_with_filtered_relation' parameter is used when computing + FilteredRelation instances. + A join is always created as LOUTER if the lhs alias is LOUTER to make sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new joins are created as LOUTER if the join is nullable. """ - reuse = [a for a, j in self.alias_map.items() - if (reuse is None or a in reuse) and j == join] - if reuse: - self.ref_alias(reuse[0]) - return reuse[0] + if reuse_with_filtered_relation and reuse: + reuse_aliases = [ + a for a, j in self.alias_map.items() + if a in reuse and j.equals(join, with_filtered_relation=False) + ] + else: + reuse_aliases = [ + a for a, j in self.alias_map.items() + if (reuse is None or a in reuse) and j == join + ] + if reuse_aliases: + self.ref_alias(reuse_aliases[0]) + return reuse_aliases[0] # No reuse is possible, so we need a new alias. - alias, _ = self.table_alias(join.table_name, create=True) + alias, _ = self.table_alias(join.table_name, create=True, filtered_relation=join.filtered_relation) if join.join_type: if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable: join_type = LOUTER @@ -1090,7 +1114,8 @@ class Query: (name, lhs.output_field.__class__.__name__)) def build_filter(self, filter_expr, branch_negated=False, current_negated=False, - can_reuse=None, allow_joins=True, split_subq=True): + can_reuse=None, allow_joins=True, split_subq=True, + reuse_with_filtered_relation=False): """ Build a WhereNode for a single filter clause but don't add it to this Query. Query.add_q() will then add this filter to the where @@ -1112,6 +1137,9 @@ class Query: The 'can_reuse' is a set of reusable joins for multijoins. + If 'reuse_with_filtered_relation' is True, then only joins in can_reuse + will be reused. + The method will create a filter clause that can be added to the current query. However, if the filter isn't added to the query then the caller is responsible for unreffing the joins used. @@ -1147,7 +1175,10 @@ class Query: allow_many = not branch_negated or not split_subq try: - join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many) + join_info = self.setup_joins( + parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many, + reuse_with_filtered_relation=reuse_with_filtered_relation, + ) # Prevent iterator from being consumed by check_related_objects() if isinstance(value, Iterator): @@ -1250,6 +1281,41 @@ class Query: needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner + def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, current_negated=False): + """Add a FilteredRelation object to the current filter.""" + connector = q_object.connector + current_negated ^= q_object.negated + branch_negated = branch_negated or q_object.negated + target_clause = self.where_class(connector=connector, negated=q_object.negated) + for child in q_object.children: + if isinstance(child, Node): + child_clause = self.build_filtered_relation_q( + child, reuse=reuse, branch_negated=branch_negated, + current_negated=current_negated, + ) + else: + child_clause, _ = self.build_filter( + child, can_reuse=reuse, branch_negated=branch_negated, + current_negated=current_negated, + allow_joins=True, split_subq=False, + reuse_with_filtered_relation=True, + ) + target_clause.add(child_clause, connector) + return target_clause + + def add_filtered_relation(self, filtered_relation, alias): + filtered_relation.alias = alias + lookups = dict(get_children_from_q(filtered_relation.condition)) + for lookup in chain((filtered_relation.relation_name,), lookups): + lookup_parts, field_parts, _ = self.solve_lookup_type(lookup) + shift = 2 if not lookup_parts else 1 + if len(field_parts) > (shift + len(lookup_parts)): + raise ValueError( + "FilteredRelation's condition doesn't support nested " + "relations (got %r)." % lookup + ) + self._filtered_relations[filtered_relation.alias] = filtered_relation + def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): """ Walk the list of names and turns them into PathInfo tuples. A single @@ -1272,12 +1338,15 @@ class Query: name = opts.pk.name field = None + filtered_relation = None try: field = opts.get_field(name) except FieldDoesNotExist: if name in self.annotation_select: field = self.annotation_select[name].output_field - + elif name in self._filtered_relations and pos == 0: + filtered_relation = self._filtered_relations[name] + field = opts.get_field(filtered_relation.relation_name) if field is not None: # Fields that contain one-to-many relations with a generic # model (like a GenericForeignKey) cannot generate reverse @@ -1301,7 +1370,10 @@ class Query: pos -= 1 if pos == -1 or fail_on_missing: field_names = list(get_field_names_from_opts(opts)) - available = sorted(field_names + list(self.annotation_select)) + available = sorted( + field_names + list(self.annotation_select) + + list(self._filtered_relations) + ) raise FieldError("Cannot resolve keyword '%s' into field. " "Choices are: %s" % (name, ", ".join(available))) break @@ -1315,7 +1387,7 @@ class Query: cur_names_with_path[1].extend(path_to_parent) opts = path_to_parent[-1].to_opts if hasattr(field, 'get_path_info'): - pathinfos = field.get_path_info() + pathinfos = field.get_path_info(filtered_relation) if not allow_many: for inner_pos, p in enumerate(pathinfos): if p.m2m: @@ -1340,7 +1412,8 @@ class Query: break return path, final_field, targets, names[pos + 1:] - def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): + def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, + reuse_with_filtered_relation=False): """ Compute the necessary table joins for the passage through the fields given in 'names'. 'opts' is the Options class for the current model @@ -1352,6 +1425,9 @@ class Query: that can be reused. Note that non-reverse foreign keys are always reusable when using setup_joins(). + The 'reuse_with_filtered_relation' can be used to force 'can_reuse' + parameter and force the relation on the given connections. + If 'allow_many' is False, then any reverse foreign key seen will generate a MultiJoin exception. @@ -1374,15 +1450,29 @@ class Query: # joins at this stage - we will need the information about join type # of the trimmed joins. for join in path: + if join.filtered_relation: + filtered_relation = join.filtered_relation.clone() + table_alias = filtered_relation.alias + else: + filtered_relation = None + table_alias = None opts = join.to_opts if join.direct: nullable = self.is_nullable(join.join_field) else: nullable = True - connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable) - reuse = can_reuse if join.m2m else None - alias = self.join(connection, reuse=reuse) + connection = Join( + opts.db_table, alias, table_alias, INNER, join.join_field, + nullable, filtered_relation=filtered_relation, + ) + reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None + alias = self.join( + connection, reuse=reuse, + reuse_with_filtered_relation=reuse_with_filtered_relation, + ) joins.append(alias) + if filtered_relation: + filtered_relation.path = joins[:] return JoinInfo(final_field, targets, opts, joins, path) def trim_joins(self, targets, joins, path): @@ -1402,6 +1492,8 @@ class Query: for pos, info in enumerate(reversed(path)): if len(joins) == 1 or not info.direct: break + if info.filtered_relation: + break join_targets = {t.column for t in info.join_field.foreign_related_fields} cur_targets = {t.column for t in targets} if not cur_targets.issubset(join_targets): @@ -1425,7 +1517,7 @@ class Query: return self.annotation_select[name] else: field_list = name.split(LOOKUP_SEP) - join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse) + join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse) targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) if len(targets) > 1: raise FieldError("Referencing multicolumn fields with F() objects " @@ -1602,7 +1694,10 @@ class Query: # from the model on which the lookup failed. raise else: - names = sorted(list(get_field_names_from_opts(opts)) + list(self.extra) + list(self.annotation_select)) + names = sorted( + list(get_field_names_from_opts(opts)) + list(self.extra) + + list(self.annotation_select) + list(self._filtered_relations) + ) raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) |