summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2018-11-15 15:05:11 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2018-11-15 15:05:11 +0000
commitbfd6f76a7dd48d06a2d55e1b6dba191817e144ce (patch)
tree709ba9c218001a6efeb2ac04fda4e17d143027fd /lib/sqlalchemy/sql
parent0a07fd99dbb8122f8b0786d693506c849db58d9e (diff)
parent7dcfd1e019e1c0ebceba06d6684f5bf64a2efb71 (diff)
downloadsqlalchemy-bfd6f76a7dd48d06a2d55e1b6dba191817e144ce.tar.gz
Merge "Allow join() to pick the best candidate from multiple froms/entities"
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/selectable.py14
-rw-r--r--lib/sqlalchemy/sql/util.py89
2 files changed, 101 insertions, 2 deletions
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 64886b326..f64f152c4 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -987,6 +987,19 @@ class Join(FromClause):
return and_(*crit)
@classmethod
+ def _can_join(cls, left, right, consider_as_foreign_keys=None):
+ if isinstance(left, Join):
+ left_right = left.right
+ else:
+ left_right = None
+
+ constraints = cls._joincond_scan_left_right(
+ a=left, b=right, a_subset=left_right,
+ consider_as_foreign_keys=consider_as_foreign_keys)
+
+ return bool(constraints)
+
+ @classmethod
def _joincond_scan_left_right(
cls, a, a_subset, b, consider_as_foreign_keys):
constraints = collections.defaultdict(list)
@@ -1059,6 +1072,7 @@ class Join(FromClause):
"Please specify the 'onclause' of this "
"join explicitly." % (a.description, b.description))
+
def select(self, whereclause=None, **kwargs):
r"""Create a :class:`.Select` from this :class:`.Join`.
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index f15cc95ed..12cfe09d1 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -49,12 +49,97 @@ def find_join_source(clauses, join_to):
"""
selectables = list(_from_objects(join_to))
+ idx = []
for i, f in enumerate(clauses):
for s in selectables:
if f.is_derived_from(s):
- return i, f
+ idx.append(i)
+ return idx
+
+
+def find_left_clause_that_matches_given(clauses, join_from):
+ """Given a list of FROM clauses and a selectable,
+ return the indexes from the list of
+ clauses which is derived from the selectable.
+
+ """
+
+ selectables = list(_from_objects(join_from))
+ liberal_idx = []
+ for i, f in enumerate(clauses):
+ for s in selectables:
+ # basic check, if f is derived from s.
+ # this can be joins containing a table, or an aliased table
+ # or select statement matching to a table. This check
+ # will match a table to a selectable that is adapted from
+ # that table. With Query, this suits the case where a join
+ # is being made to an adapted entity
+ if f.is_derived_from(s):
+ liberal_idx.append(i)
+ break
+
+ # in an extremely small set of use cases, a join is being made where
+ # there are multiple FROM clauses where our target table is represented
+ # in more than one, such as embedded or similar. in this case, do
+ # another pass where we try to get a more exact match where we aren't
+ # looking at adaption relationships.
+ if len(liberal_idx) > 1:
+ conservative_idx = []
+ for idx in liberal_idx:
+ f = clauses[idx]
+ for s in selectables:
+ if set(surface_selectables(f)).\
+ intersection(surface_selectables(s)):
+ conservative_idx.append(idx)
+ break
+ if conservative_idx:
+ return conservative_idx
+
+ return liberal_idx
+
+
+def find_left_clause_to_join_from(clauses, join_to, onclause):
+ """Given a list of FROM clauses, a selectable,
+ and optional ON clause, return a list of integer indexes from the
+ clauses list indicating the clauses that can be joined from.
+
+ The presense of an "onclause" indicates that at least one clause can
+ definitely be joined from; if the list of clauses is of length one
+ and the onclause is given, returns that index. If the list of clauses
+ is more than length one, and the onclause is given, attempts to locate
+ which clauses contain the same columns.
+
+ """
+ idx = []
+ selectables = set(_from_objects(join_to))
+
+ # if we are given more than one target clause to join
+ # from, use the onclause to provide a more specific answer.
+ # otherwise, don't try to limit, after all, "ON TRUE" is a valid
+ # on clause
+ if len(clauses) > 1 and onclause is not None:
+ resolve_ambiguity = True
+ cols_in_onclause = _find_columns(onclause)
+ else:
+ resolve_ambiguity = False
+ cols_in_onclause = None
+
+ for i, f in enumerate(clauses):
+ for s in selectables.difference([f]):
+ if resolve_ambiguity:
+ if set(f.c).union(s.c).issuperset(cols_in_onclause):
+ idx.append(i)
+ break
+ elif Join._can_join(f, s) or onclause is not None:
+ idx.append(i)
+ break
+
+ # onclause was given and none of them resolved, so assume
+ # all indexes can match
+ if not idx and onclause is not None:
+ return range(len(clauses))
else:
- return None, None
+ return idx
def visit_binary_product(fn, expr):