diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2018-11-12 18:27:34 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2018-11-14 21:35:15 -0500 |
| commit | 7dcfd1e019e1c0ebceba06d6684f5bf64a2efb71 (patch) | |
| tree | 35a6a875d1cee208b22884919359128b078eaae6 /lib/sqlalchemy | |
| parent | a698bdbc5716201804ddedde6a0fc5ab33d43300 (diff) | |
| download | sqlalchemy-7dcfd1e019e1c0ebceba06d6684f5bf64a2efb71.tar.gz | |
Allow join() to pick the best candidate from multiple froms/entities
Refactored :meth:`.Query.join` to further clarify the individual components
of structuring the join. This refactor adds the ability for
:meth:`.Query.join` to determine the most appropriate "left" side of the
join when there is more than one element in the FROM list or the query is
against multiple entities. In particular this targets the regression we
saw in :ticket:`4363` but is also of general use. The codepaths within
:meth:`.Query.join` are now easier to follow and the error cases are
decided more specifically at an earlier point in the operation.
Fixes: #4365
Change-Id: I403f451243904a020ceab4c3f94bead550c7b2d5
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 426 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 89 |
4 files changed, 403 insertions, 142 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index cace2e54a..93b9a85be 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2182,28 +2182,37 @@ class Query(object): # convert to a tuple. keys = (keys,) + # Query.join() accepts a list of join paths all at once. + # step one is to iterate through these paths and determine the + # intent of each path individually. as we encounter a path token, + # we add a new ORMJoin construct to the self._from_obj tuple, + # either by adding a new element to it, or by replacing an existing + # element with a new ORMJoin. keylist = util.to_list(keys) for idx, arg1 in enumerate(keylist): if isinstance(arg1, tuple): # "tuple" form of join, multiple # tuples are accepted as well. The simpler - # "2-arg" form is preferred. May deprecate - # the "tuple" usage. + # "2-arg" form is preferred. arg1, arg2 = arg1 else: arg2 = None # determine onclause/right_entity. there # is a little bit of legacy behavior still at work here - # which means they might be in either order. may possibly - # lock this down to (right_entity, onclause) in 0.6. + # which means they might be in either order. if isinstance( arg1, (interfaces.PropComparator, util.string_types)): - right_entity, onclause = arg2, arg1 + right, onclause = arg2, arg1 else: - right_entity, onclause = arg1, arg2 + right, onclause = arg1, arg2 - left_entity = prop = None + if onclause is None: + r_info = inspect(right) + if not r_info.is_selectable and not hasattr(r_info, 'mapper'): + raise sa_exc.ArgumentError( + "Expected mapped entity or " + "selectable/table as join target") if isinstance(onclause, interfaces.PropComparator): of_type = getattr(onclause, '_of_type', None) @@ -2211,44 +2220,46 @@ class Query(object): of_type = None if isinstance(onclause, util.string_types): - left_entity = self._joinpoint_zero() - - descriptor = _entity_descriptor(left_entity, onclause) - onclause = descriptor + # string given, e.g. query(Foo).join("bar"). + # we look to the left entity or what we last joined + # towards + onclause = _entity_descriptor(self._joinpoint_zero(), onclause) # check for q.join(Class.propname, from_joinpoint=True) - # and Class is that of the current joinpoint + # and Class corresponds at the mapper level to the current + # joinpoint. this match intentionally looks for a non-aliased + # class-bound descriptor as the onclause and if it matches the + # current joinpoint at the mapper level, it's used. This + # is a very old use case that is intended to make it easier + # to work with the aliased=True flag, which is also something + # that probably shouldn't exist on join() due to its high + # complexity/usefulness ratio elif from_joinpoint and \ isinstance(onclause, interfaces.PropComparator): - left_entity = onclause._parententity + jp0 = self._joinpoint_zero() + info = inspect(jp0) - info = inspect(self._joinpoint_zero()) - left_mapper, left_selectable, left_is_aliased = \ - getattr(info, 'mapper', None), \ - info.selectable, \ - getattr(info, 'is_aliased_class', None) - - if left_mapper is left_entity: - left_entity = self._joinpoint_zero() - descriptor = _entity_descriptor(left_entity, - onclause.key) - onclause = descriptor + if getattr(info, 'mapper', None) is onclause._parententity: + onclause = _entity_descriptor(jp0, onclause.key) if isinstance(onclause, interfaces.PropComparator): - if right_entity is None: + # descriptor/property given (or determined); this tells + # us explicitly what the expected "left" side of the join is. + if right is None: if of_type: - right_entity = of_type + right = of_type else: - right_entity = onclause.property.mapper + right = onclause.property.mapper + + left = onclause._parententity - left_entity = onclause._parententity + alias = self._polymorphic_adapters.get(left, None) - alias = self._polymorphic_adapters.get(left_entity, None) # could be None or could be ColumnAdapter also if isinstance(alias, ORMAdapter) and \ - alias.mapper.isa(left_entity): - left_entity = alias.aliased_class - onclause = getattr(left_entity, onclause.key) + alias.mapper.isa(left): + left = alias.aliased_class + onclause = getattr(left, onclause.key) prop = onclause.property if not isinstance(onclause, attributes.QueryableAttribute): @@ -2257,7 +2268,7 @@ class Query(object): if not create_aliases: # check for this path already present. # don't render in that case. - edge = (left_entity, right_entity, prop.key) + edge = (left, right, prop.key) if edge in self._joinpoint: # The child's prev reference might be stale -- # it could point to a parent older than the @@ -2270,50 +2281,248 @@ class Query(object): jp['prev'] = (edge, self._joinpoint) self._update_joinpoint(jp) + # warn only on the last element of the list if idx == len(keylist) - 1: util.warn( "Pathed join target %s has already " "been joined to; skipping" % prop) continue + else: + # no descriptor/property given; we will need to figure out + # what the effective "left" side is + prop = left = None - elif onclause is not None and right_entity is None: - # TODO: no coverage here - raise NotImplementedError("query.join(a==b) not supported.") - + # figure out the final "left" and "right" sides and create an + # ORMJoin to add to our _from_obj tuple self._join_left_to_right( - left_entity, - right_entity, onclause, - outerjoin, full, create_aliases, prop) + left, right, onclause, prop, create_aliases, + outerjoin, full + ) - def _join_left_to_right(self, left, right, - onclause, outerjoin, full, create_aliases, prop): - """append a JOIN to the query's from clause.""" + def _join_left_to_right( + self, left, right, onclause, prop, + create_aliases, outerjoin, full): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ self._polymorphic_adapters = self._polymorphic_adapters.copy() if left is None: - if self._from_obj: - left = self._from_obj[0] - elif self._entities: - left = self._entities[0].entity_zero_or_selectable + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + left, replace_from_obj_index, use_entity_index = \ + self._join_determine_implicit_left_side(left, right, onclause) + else: + # left is given via a relationship/name. Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + assert prop is not None + replace_from_obj_index, use_entity_index = \ + self._join_place_explicit_left_side(left) + + # this should never happen because we would not have found a place + # to join on + assert left is not right or create_aliases + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. handle all that and + # get back the the new effective "right" side + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, create_aliases, + ) - if left is None: - if self._entities: - problem = "Don't know how to join from %s" % self._entities[0] + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self._from_obj[replace_from_obj_index] + + self._from_obj = ( + self._from_obj[:replace_from_obj_index] + + (orm_join( + left_clause, right, + onclause, isouter=outerjoin, full=full), ) + + self._from_obj[replace_from_obj_index + 1:]) + else: + # add a new element to the self._from_obj list + + if use_entity_index is not None: + # why doesn't this work as .entity_zero_or_selectable? + left_clause = self._entities[use_entity_index].selectable + else: + left_clause = left + + self._from_obj = self._from_obj + ( + orm_join( + left_clause, right, onclause, + isouter=outerjoin, full=full), + ) + + def _join_determine_implicit_left_side(self, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self._from_obj: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + indexes = sql_util.find_left_clause_to_join_from( + self._from_obj, + r_info.selectable, onclause) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self._from_obj[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Try adding an explicit ON clause " + "to help resolve the ambiguity.") else: - problem = "No entities to join from" + raise sa_exc.InvalidRequestError( + "Don't know how to join to %s; please use " + "an ON clause to more clearly establish the left " + "side of this join" % (right, ) + ) + + elif self._entities: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. + + potential = {} + for entity_index, ent in enumerate(self._entities): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, _MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + all_clauses = list(potential.keys()) + indexes = sql_util.find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Try adding an explicit ON clause " + "to help resolve the ambiguity.") + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %s; please use " + "an ON clause to more clearly establish the left " + "side of this join" % (right, ) + ) + else: raise sa_exc.InvalidRequestError( - "%s; please use " + "No entities to join from; please use " "select_from() to establish the left " - "entity/selectable of this join" % problem) + "entity/selectable of this join") - if left is right and \ - not create_aliases: - raise sa_exc.InvalidRequestError( - "Can't construct a join from %s to %s, they " - "are the same entity" % - (left, right)) + return left, replace_from_obj_index, use_entity_index + + def _join_place_explicit_left_side(self, left): + """When join conditions express a left side explicitly, determine + where in our existing list of FROM clauses we should join towards, + or if we need to make a new join, and if so is it from one of our + existing entities. + + """ + + # when we are here, it means join() was called with an indicator + # as to an exact left side, which means a path to a + # RelationshipProperty was given, e.g.: + # + # join(RightEntity, LeftEntity.right) + # + # or + # + # join(LeftEntity.right) + # + # as well as string forms: + # + # join(RightEntity, "right") + # + # etc. + # + + replace_from_obj_index = use_entity_index = None + + l_info = inspect(left) + if self._from_obj: + indexes = sql_util.find_left_clause_that_matches_given( + self._from_obj, l_info.selectable) + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't identify which entity in which to assign the " + "left side of this join. Please use a more specific " + "ON clause.") + + # have an index, means the left side is already present in + # an existing FROM in the self._from_obj tuple + if indexes: + replace_from_obj_index = indexes[0] + + # no index, means we need to add a new element to the + # self._from_obj tuple + + # no from element present, so we will have to add to the + # self._from_obj tuple. Determine if this left side matches up + # with existing mapper entities, in which case we want to apply the + # aliasing / adaptation rules present on that entity if any + if replace_from_obj_index is None and \ + self._entities and hasattr(l_info, 'mapper'): + for idx, ent in enumerate(self._entities): + # TODO: should we be checking for multiple mapper entities + # matching? + if isinstance(ent, _MapperEntity) and ent.corresponds_to(left): + use_entity_index = idx + break + + return replace_from_obj_index, use_entity_index + + def _join_check_and_adapt_right_side( + self, left, right, onclause, prop, create_aliases): + """transform the "right" side of the join as well as the onclause + according to polymorphic mapping translations, aliasing on the query + or on the join, special cases where the right and left side have + overlapping tables. + + """ l_info = inspect(left) r_info = inspect(right) @@ -2341,34 +2550,10 @@ class Query(object): "Can't join table/selectable '%s' to itself" % l_info.selectable) - right, onclause = self._prepare_right_side( - r_info, right, onclause, - create_aliases, - prop, overlap) - - # if joining on a MapperProperty path, - # track the path to prevent redundant joins - if not create_aliases and prop: - self._update_joinpoint({ - '_joinpoint_entity': right, - 'prev': ((left, right, prop.key), self._joinpoint) - }) - else: - self._joinpoint = {'_joinpoint_entity': right} - - self._join_to_left(l_info, left, right, onclause, outerjoin, full) - - def _prepare_right_side(self, r_info, right, onclause, create_aliases, - prop, overlap): - info = r_info - right_mapper, right_selectable, right_is_aliased = \ - getattr(info, 'mapper', None), \ - info.selectable, \ - getattr(info, 'is_aliased_class', False) - - if right_mapper: - self._join_entities += (info, ) + getattr(r_info, 'mapper', None), \ + r_info.selectable, \ + getattr(r_info, 'is_aliased_class', False) if right_mapper and prop and \ not right_mapper.common_parent(prop.mapper): @@ -2377,6 +2562,11 @@ class Query(object): "the right side of join condition %s" % (right, onclause) ) + # _join_entities is used as a hint for single-table inheritance + # purposes at the moment + if hasattr(r_info, 'mapper'): + self._join_entities += (r_info, ) + if not right_mapper and prop: right_mapper = prop.mapper @@ -2408,9 +2598,8 @@ class Query(object): ( right_mapper.with_polymorphic and isinstance( right_mapper._with_polymorphic_selectable, - expression.Alias) - or - overlap # test for overlap: + expression.Alias) or overlap + # test for overlap: # orm/inheritance/relationships.py # SelfReferentialM2MTest ) @@ -2447,52 +2636,17 @@ class Query(object): ) ) - return right, onclause - - def _join_to_left(self, l_info, left, right, onclause, outerjoin, full): - info = l_info - left_mapper = getattr(info, 'mapper', None) - left_selectable = info.selectable - - if self._from_obj: - replace_clause_index, clause = sql_util.find_join_source( - self._from_obj, - left_selectable) - if clause is not None: - try: - clause = orm_join(clause, - right, - onclause, isouter=outerjoin, full=full) - except sa_exc.ArgumentError as ae: - raise sa_exc.InvalidRequestError( - "Could not find a FROM clause to join from. " - "Tried joining to %s, but got: %s" % (right, ae)) - - self._from_obj = \ - self._from_obj[:replace_clause_index] + \ - (clause, ) + \ - self._from_obj[replace_clause_index + 1:] - return - - if left_mapper: - for ent in self._entities: - if ent.corresponds_to(left): - clause = ent.selectable - break - else: - clause = left + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if not create_aliases and prop: + self._update_joinpoint({ + '_joinpoint_entity': right, + 'prev': ((left, right, prop.key), self._joinpoint) + }) else: - clause = left_selectable + self._joinpoint = {'_joinpoint_entity': right} - assert clause is not None - try: - clause = orm_join( - clause, right, onclause, isouter=outerjoin, full=full) - except sa_exc.ArgumentError as ae: - raise sa_exc.InvalidRequestError( - "Could not find a FROM clause to join from. " - "Tried joining to %s, but got: %s" % (right, ae)) - self._from_obj = self._from_obj + (clause,) + return right, inspect(right), onclause def _reset_joinpoint(self): self._joinpoint = self._joinpath @@ -4049,8 +4203,8 @@ class _BundleEntity(_QueryEntity): return None def corresponds_to(self, entity): - # TODO: this seems to have no effect for - # _ColumnEntity either + # TODO: we might be able to implement this but for now + # we are working around it return False @property @@ -4226,8 +4380,6 @@ class _ColumnEntity(_QueryEntity): self.froms.add(ext_info.selectable) def corresponds_to(self, entity): - # TODO: just returning False here, - # no tests fail if self.entity_zero is None: return False elif _is_aliased_class(entity): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 9d83952d3..47791f9b9 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1544,14 +1544,24 @@ class JoinedLoader(AbstractRelationshipLoader): if entity not in context.eager_joins and \ not should_nest_selectable and \ context.from_clause: - index, clause = sql_util.find_join_source( + indexes = sql_util.find_left_clause_that_matches_given( context.from_clause, entity.selectable) - if clause is not None: + + if len(indexes) > 1: + # for the eager load case, I can't reproduce this right + # now. For query.join() I can. + raise sa_exc.InvalidRequestError( + "Can't identify which entity in which to joined eager " + "load from. Please use an exact match when specifying " + "the join path.") + + if indexes: + clause = context.from_clause[indexes[0]] # join to an existing FROM clause on the query. # key it to its list index in the eager_joins dict. # Query._compile_context will adapt as needed and # append to the FROM clause of the select(). - entity_key, default_towrap = index, clause + entity_key, default_towrap = indexes[0], clause if entity_key is None: entity_key, default_towrap = entity, entity.selectable diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 64886b326..f64f152c4 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -987,6 +987,19 @@ class Join(FromClause): return and_(*crit) @classmethod + def _can_join(cls, left, right, consider_as_foreign_keys=None): + if isinstance(left, Join): + left_right = left.right + else: + left_right = None + + constraints = cls._joincond_scan_left_right( + a=left, b=right, a_subset=left_right, + consider_as_foreign_keys=consider_as_foreign_keys) + + return bool(constraints) + + @classmethod def _joincond_scan_left_right( cls, a, a_subset, b, consider_as_foreign_keys): constraints = collections.defaultdict(list) @@ -1059,6 +1072,7 @@ class Join(FromClause): "Please specify the 'onclause' of this " "join explicitly." % (a.description, b.description)) + def select(self, whereclause=None, **kwargs): r"""Create a :class:`.Select` from this :class:`.Join`. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f15cc95ed..12cfe09d1 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -49,12 +49,97 @@ def find_join_source(clauses, join_to): """ selectables = list(_from_objects(join_to)) + idx = [] for i, f in enumerate(clauses): for s in selectables: if f.is_derived_from(s): - return i, f + idx.append(i) + return idx + + +def find_left_clause_that_matches_given(clauses, join_from): + """Given a list of FROM clauses and a selectable, + return the indexes from the list of + clauses which is derived from the selectable. + + """ + + selectables = list(_from_objects(join_from)) + liberal_idx = [] + for i, f in enumerate(clauses): + for s in selectables: + # basic check, if f is derived from s. + # this can be joins containing a table, or an aliased table + # or select statement matching to a table. This check + # will match a table to a selectable that is adapted from + # that table. With Query, this suits the case where a join + # is being made to an adapted entity + if f.is_derived_from(s): + liberal_idx.append(i) + break + + # in an extremely small set of use cases, a join is being made where + # there are multiple FROM clauses where our target table is represented + # in more than one, such as embedded or similar. in this case, do + # another pass where we try to get a more exact match where we aren't + # looking at adaption relationships. + if len(liberal_idx) > 1: + conservative_idx = [] + for idx in liberal_idx: + f = clauses[idx] + for s in selectables: + if set(surface_selectables(f)).\ + intersection(surface_selectables(s)): + conservative_idx.append(idx) + break + if conservative_idx: + return conservative_idx + + return liberal_idx + + +def find_left_clause_to_join_from(clauses, join_to, onclause): + """Given a list of FROM clauses, a selectable, + and optional ON clause, return a list of integer indexes from the + clauses list indicating the clauses that can be joined from. + + The presense of an "onclause" indicates that at least one clause can + definitely be joined from; if the list of clauses is of length one + and the onclause is given, returns that index. If the list of clauses + is more than length one, and the onclause is given, attempts to locate + which clauses contain the same columns. + + """ + idx = [] + selectables = set(_from_objects(join_to)) + + # if we are given more than one target clause to join + # from, use the onclause to provide a more specific answer. + # otherwise, don't try to limit, after all, "ON TRUE" is a valid + # on clause + if len(clauses) > 1 and onclause is not None: + resolve_ambiguity = True + cols_in_onclause = _find_columns(onclause) + else: + resolve_ambiguity = False + cols_in_onclause = None + + for i, f in enumerate(clauses): + for s in selectables.difference([f]): + if resolve_ambiguity: + if set(f.c).union(s.c).issuperset(cols_in_onclause): + idx.append(i) + break + elif Join._can_join(f, s) or onclause is not None: + idx.append(i) + break + + # onclause was given and none of them resolved, so assume + # all indexes can match + if not idx and onclause is not None: + return range(len(clauses)) else: - return None, None + return idx def visit_binary_product(fn, expr): |
