From 7dcfd1e019e1c0ebceba06d6684f5bf64a2efb71 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 12 Nov 2018 18:27:34 -0500 Subject: Allow join() to pick the best candidate from multiple froms/entities Refactored :meth:`.Query.join` to further clarify the individual components of structuring the join. This refactor adds the ability for :meth:`.Query.join` to determine the most appropriate "left" side of the join when there is more than one element in the FROM list or the query is against multiple entities. In particular this targets the regression we saw in :ticket:`4363` but is also of general use. The codepaths within :meth:`.Query.join` are now easier to follow and the error cases are decided more specifically at an earlier point in the operation. Fixes: #4365 Change-Id: I403f451243904a020ceab4c3f94bead550c7b2d5 --- lib/sqlalchemy/sql/selectable.py | 14 +++++++ lib/sqlalchemy/sql/util.py | 89 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 101 insertions(+), 2 deletions(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 64886b326..f64f152c4 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -986,6 +986,19 @@ class Join(FromClause): else: return and_(*crit) + @classmethod + def _can_join(cls, left, right, consider_as_foreign_keys=None): + if isinstance(left, Join): + left_right = left.right + else: + left_right = None + + constraints = cls._joincond_scan_left_right( + a=left, b=right, a_subset=left_right, + consider_as_foreign_keys=consider_as_foreign_keys) + + return bool(constraints) + @classmethod def _joincond_scan_left_right( cls, a, a_subset, b, consider_as_foreign_keys): @@ -1059,6 +1072,7 @@ class Join(FromClause): "Please specify the 'onclause' of this " "join explicitly." % (a.description, b.description)) + def select(self, whereclause=None, **kwargs): r"""Create a :class:`.Select` from this :class:`.Join`. 
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f15cc95ed..12cfe09d1 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -49,12 +49,97 @@ def find_join_source(clauses, join_to): """ selectables = list(_from_objects(join_to)) + idx = [] for i, f in enumerate(clauses): for s in selectables: if f.is_derived_from(s): - return i, f + idx.append(i) + return idx + + +def find_left_clause_that_matches_given(clauses, join_from): + """Given a list of FROM clauses and a selectable, + return the indexes from the list of + clauses which are derived from the selectable. + + """ + + selectables = list(_from_objects(join_from)) + liberal_idx = [] + for i, f in enumerate(clauses): + for s in selectables: + # basic check, if f is derived from s. + # this can be joins containing a table, or an aliased table + # or select statement matching to a table. This check + # will match a table to a selectable that is adapted from + # that table. With Query, this suits the case where a join + # is being made to an adapted entity + if f.is_derived_from(s): + liberal_idx.append(i) + break + + # in an extremely small set of use cases, a join is being made where + # there are multiple FROM clauses where our target table is represented + # in more than one, such as embedded or similar. in this case, do + # another pass where we try to get a more exact match where we aren't + # looking at adaption relationships. 
+ if len(liberal_idx) > 1: + conservative_idx = [] + for idx in liberal_idx: + f = clauses[idx] + for s in selectables: + if set(surface_selectables(f)).\ + intersection(surface_selectables(s)): + conservative_idx.append(idx) + break + if conservative_idx: + return conservative_idx + + return liberal_idx + + +def find_left_clause_to_join_from(clauses, join_to, onclause): + """Given a list of FROM clauses, a selectable, + and optional ON clause, return a list of integer indexes from the + clauses list indicating the clauses that can be joined from. + + The presence of an "onclause" indicates that at least one clause can + definitely be joined from; if the list of clauses is of length one + and the onclause is given, returns that index. If the list of clauses + is more than length one, and the onclause is given, attempts to locate + which clauses contain the same columns. + + """ + idx = [] + selectables = set(_from_objects(join_to)) + + # if we are given more than one target clause to join + # from, use the onclause to provide a more specific answer. + # otherwise, don't try to limit, after all, "ON TRUE" is a valid + # on clause + if len(clauses) > 1 and onclause is not None: + resolve_ambiguity = True + cols_in_onclause = _find_columns(onclause) + else: + resolve_ambiguity = False + cols_in_onclause = None + + for i, f in enumerate(clauses): + for s in selectables.difference([f]): + if resolve_ambiguity: + if set(f.c).union(s.c).issuperset(cols_in_onclause): + idx.append(i) + break + elif Join._can_join(f, s) or onclause is not None: + idx.append(i) + break + + # onclause was given and none of them resolved, so assume + # all indexes can match + if not idx and onclause is not None: + return range(len(clauses)) else: - return None, None + return idx def visit_binary_product(fn, expr): -- cgit v1.2.1