summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2018-11-12 18:27:34 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2018-11-14 21:35:15 -0500
commit7dcfd1e019e1c0ebceba06d6684f5bf64a2efb71 (patch)
tree35a6a875d1cee208b22884919359128b078eaae6 /lib/sqlalchemy/sql/util.py
parenta698bdbc5716201804ddedde6a0fc5ab33d43300 (diff)
downloadsqlalchemy-7dcfd1e019e1c0ebceba06d6684f5bf64a2efb71.tar.gz
Allow join() to pick the best candidate from multiple froms/entities
Refactored :meth:`.Query.join` to further clarify the individual components of structuring the join. This refactor adds the ability for :meth:`.Query.join` to determine the most appropriate "left" side of the join when there is more than one element in the FROM list or the query is against multiple entities. In particular this targets the regression we saw in :ticket:`4363` but is also of general use. The codepaths within :meth:`.Query.join` are now easier to follow and the error cases are decided more specifically at an earlier point in the operation. Fixes: #4365 Change-Id: I403f451243904a020ceab4c3f94bead550c7b2d5
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py89
1 files changed, 87 insertions, 2 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index f15cc95ed..12cfe09d1 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -49,12 +49,97 @@ def find_join_source(clauses, join_to):
"""
selectables = list(_from_objects(join_to))
+ idx = []
for i, f in enumerate(clauses):
for s in selectables:
if f.is_derived_from(s):
- return i, f
+ idx.append(i)
+ return idx
+
+
+def find_left_clause_that_matches_given(clauses, join_from):
+ """Given a list of FROM clauses and a selectable,
+ return the indexes from the list of
+ clauses which is derived from the selectable.
+
+ """
+
+ selectables = list(_from_objects(join_from))
+ liberal_idx = []
+ for i, f in enumerate(clauses):
+ for s in selectables:
+ # basic check, if f is derived from s.
+ # this can be joins containing a table, or an aliased table
+ # or select statement matching to a table. This check
+ # will match a table to a selectable that is adapted from
+ # that table. With Query, this suits the case where a join
+ # is being made to an adapted entity
+ if f.is_derived_from(s):
+ liberal_idx.append(i)
+ break
+
+ # in an extremely small set of use cases, a join is being made where
+ # there are multiple FROM clauses where our target table is represented
+ # in more than one, such as embedded or similar. in this case, do
+ # another pass where we try to get a more exact match where we aren't
+ # looking at adaption relationships.
+ if len(liberal_idx) > 1:
+ conservative_idx = []
+ for idx in liberal_idx:
+ f = clauses[idx]
+ for s in selectables:
+ if set(surface_selectables(f)).\
+ intersection(surface_selectables(s)):
+ conservative_idx.append(idx)
+ break
+ if conservative_idx:
+ return conservative_idx
+
+ return liberal_idx
+
+
+def find_left_clause_to_join_from(clauses, join_to, onclause):
+ """Given a list of FROM clauses, a selectable,
+ and optional ON clause, return a list of integer indexes from the
+ clauses list indicating the clauses that can be joined from.
+
+ The presense of an "onclause" indicates that at least one clause can
+ definitely be joined from; if the list of clauses is of length one
+ and the onclause is given, returns that index. If the list of clauses
+ is more than length one, and the onclause is given, attempts to locate
+ which clauses contain the same columns.
+
+ """
+ idx = []
+ selectables = set(_from_objects(join_to))
+
+ # if we are given more than one target clause to join
+ # from, use the onclause to provide a more specific answer.
+ # otherwise, don't try to limit, after all, "ON TRUE" is a valid
+ # on clause
+ if len(clauses) > 1 and onclause is not None:
+ resolve_ambiguity = True
+ cols_in_onclause = _find_columns(onclause)
+ else:
+ resolve_ambiguity = False
+ cols_in_onclause = None
+
+ for i, f in enumerate(clauses):
+ for s in selectables.difference([f]):
+ if resolve_ambiguity:
+ if set(f.c).union(s.c).issuperset(cols_in_onclause):
+ idx.append(i)
+ break
+ elif Join._can_join(f, s) or onclause is not None:
+ idx.append(i)
+ break
+
+ # onclause was given and none of them resolved, so assume
+ # all indexes can match
+ if not idx and onclause is not None:
+ return range(len(clauses))
else:
- return None, None
+ return idx
def visit_binary_product(fn, expr):