diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2018-11-15 15:05:11 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2018-11-15 15:05:11 +0000 |
| commit | bfd6f76a7dd48d06a2d55e1b6dba191817e144ce (patch) | |
| tree | 709ba9c218001a6efeb2ac04fda4e17d143027fd /lib/sqlalchemy | |
| parent | 0a07fd99dbb8122f8b0786d693506c849db58d9e (diff) | |
| parent | 7dcfd1e019e1c0ebceba06d6684f5bf64a2efb71 (diff) | |
| download | sqlalchemy-bfd6f76a7dd48d06a2d55e1b6dba191817e144ce.tar.gz | |
Merge "Allow join() to pick the best candidate from multiple froms/entities"
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 426 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 89 |
4 files changed, 403 insertions, 142 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index cace2e54a..93b9a85be 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2182,28 +2182,37 @@ class Query(object): # convert to a tuple. keys = (keys,) + # Query.join() accepts a list of join paths all at once. + # step one is to iterate through these paths and determine the + # intent of each path individually. as we encounter a path token, + # we add a new ORMJoin construct to the self._from_obj tuple, + # either by adding a new element to it, or by replacing an existing + # element with a new ORMJoin. keylist = util.to_list(keys) for idx, arg1 in enumerate(keylist): if isinstance(arg1, tuple): # "tuple" form of join, multiple # tuples are accepted as well. The simpler - # "2-arg" form is preferred. May deprecate - # the "tuple" usage. + # "2-arg" form is preferred. arg1, arg2 = arg1 else: arg2 = None # determine onclause/right_entity. there # is a little bit of legacy behavior still at work here - # which means they might be in either order. may possibly - # lock this down to (right_entity, onclause) in 0.6. + # which means they might be in either order. if isinstance( arg1, (interfaces.PropComparator, util.string_types)): - right_entity, onclause = arg2, arg1 + right, onclause = arg2, arg1 else: - right_entity, onclause = arg1, arg2 + right, onclause = arg1, arg2 - left_entity = prop = None + if onclause is None: + r_info = inspect(right) + if not r_info.is_selectable and not hasattr(r_info, 'mapper'): + raise sa_exc.ArgumentError( + "Expected mapped entity or " + "selectable/table as join target") if isinstance(onclause, interfaces.PropComparator): of_type = getattr(onclause, '_of_type', None) @@ -2211,44 +2220,46 @@ class Query(object): of_type = None if isinstance(onclause, util.string_types): - left_entity = self._joinpoint_zero() - - descriptor = _entity_descriptor(left_entity, onclause) - onclause = descriptor + # string given, e.g. query(Foo).join("bar"). + # we look to the left entity or what we last joined + # towards + onclause = _entity_descriptor(self._joinpoint_zero(), onclause) # check for q.join(Class.propname, from_joinpoint=True) - # and Class is that of the current joinpoint + # and Class corresponds at the mapper level to the current + # joinpoint. this match intentionally looks for a non-aliased + # class-bound descriptor as the onclause and if it matches the + # current joinpoint at the mapper level, it's used. This + # is a very old use case that is intended to make it easier + # to work with the aliased=True flag, which is also something + # that probably shouldn't exist on join() due to its high + # complexity/usefulness ratio elif from_joinpoint and \ isinstance(onclause, interfaces.PropComparator): - left_entity = onclause._parententity + jp0 = self._joinpoint_zero() + info = inspect(jp0) - info = inspect(self._joinpoint_zero()) - left_mapper, left_selectable, left_is_aliased = \ - getattr(info, 'mapper', None), \ - info.selectable, \ - getattr(info, 'is_aliased_class', None) - - if left_mapper is left_entity: - left_entity = self._joinpoint_zero() - descriptor = _entity_descriptor(left_entity, - onclause.key) - onclause = descriptor + if getattr(info, 'mapper', None) is onclause._parententity: + onclause = _entity_descriptor(jp0, onclause.key) if isinstance(onclause, interfaces.PropComparator): - if right_entity is None: + # descriptor/property given (or determined); this tells + # us explicitly what the expected "left" side of the join is. + if right is None: if of_type: - right_entity = of_type + right = of_type else: - right_entity = onclause.property.mapper + right = onclause.property.mapper + + left = onclause._parententity - left_entity = onclause._parententity + alias = self._polymorphic_adapters.get(left, None) - alias = self._polymorphic_adapters.get(left_entity, None) # could be None or could be ColumnAdapter also if isinstance(alias, ORMAdapter) and \ - alias.mapper.isa(left_entity): - left_entity = alias.aliased_class - onclause = getattr(left_entity, onclause.key) + alias.mapper.isa(left): + left = alias.aliased_class + onclause = getattr(left, onclause.key) prop = onclause.property if not isinstance(onclause, attributes.QueryableAttribute): @@ -2257,7 +2268,7 @@ class Query(object): if not create_aliases: # check for this path already present. # don't render in that case. - edge = (left_entity, right_entity, prop.key) + edge = (left, right, prop.key) if edge in self._joinpoint: # The child's prev reference might be stale -- # it could point to a parent older than the @@ -2270,50 +2281,248 @@ class Query(object): jp['prev'] = (edge, self._joinpoint) self._update_joinpoint(jp) + # warn only on the last element of the list if idx == len(keylist) - 1: util.warn( "Pathed join target %s has already " "been joined to; skipping" % prop) continue + else: + # no descriptor/property given; we will need to figure out + # what the effective "left" side is + prop = left = None - elif onclause is not None and right_entity is None: - # TODO: no coverage here - raise NotImplementedError("query.join(a==b) not supported.") - + # figure out the final "left" and "right" sides and create an + # ORMJoin to add to our _from_obj tuple self._join_left_to_right( - left_entity, - right_entity, onclause, - outerjoin, full, create_aliases, prop) + left, right, onclause, prop, create_aliases, + outerjoin, full + ) - def _join_left_to_right(self, left, right, - onclause, outerjoin, full, create_aliases, prop): - """append a JOIN to the query's from clause.""" + def _join_left_to_right( + self, left, right, onclause, prop, + create_aliases, outerjoin, full): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ self._polymorphic_adapters = self._polymorphic_adapters.copy() if left is None: - if self._from_obj: - left = self._from_obj[0] - elif self._entities: - left = self._entities[0].entity_zero_or_selectable + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + left, replace_from_obj_index, use_entity_index = \ + self._join_determine_implicit_left_side(left, right, onclause) + else: + # left is given via a relationship/name. Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + assert prop is not None + replace_from_obj_index, use_entity_index = \ + self._join_place_explicit_left_side(left) + + # this should never happen because we would not have found a place + # to join on + assert left is not right or create_aliases + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. handle all that and + # get back the the new effective "right" side + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, create_aliases, + ) - if left is None: - if self._entities: - problem = "Don't know how to join from %s" % self._entities[0] + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self._from_obj[replace_from_obj_index] + + self._from_obj = ( + self._from_obj[:replace_from_obj_index] + + (orm_join( + left_clause, right, + onclause, isouter=outerjoin, full=full), ) + + self._from_obj[replace_from_obj_index + 1:]) + else: + # add a new element to the self._from_obj list + + if use_entity_index is not None: + # why doesn't this work as .entity_zero_or_selectable? + left_clause = self._entities[use_entity_index].selectable + else: + left_clause = left + + self._from_obj = self._from_obj + ( + orm_join( + left_clause, right, onclause, + isouter=outerjoin, full=full), + ) + + def _join_determine_implicit_left_side(self, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self._from_obj: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + indexes = sql_util.find_left_clause_to_join_from( + self._from_obj, + r_info.selectable, onclause) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self._from_obj[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Try adding an explicit ON clause " + "to help resolve the ambiguity.") else: - problem = "No entities to join from" + raise sa_exc.InvalidRequestError( + "Don't know how to join to %s; please use " + "an ON clause to more clearly establish the left " + "side of this join" % (right, ) + ) + + elif self._entities: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. + + potential = {} + for entity_index, ent in enumerate(self._entities): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, _MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + all_clauses = list(potential.keys()) + indexes = sql_util.find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Try adding an explicit ON clause " + "to help resolve the ambiguity.") + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %s; please use " + "an ON clause to more clearly establish the left " + "side of this join" % (right, ) + ) + else: raise sa_exc.InvalidRequestError( - "%s; please use " + "No entities to join from; please use " "select_from() to establish the left " - "entity/selectable of this join" % problem) + "entity/selectable of this join") - if left is right and \ - not create_aliases: - raise sa_exc.InvalidRequestError( - "Can't construct a join from %s to %s, they " - "are the same entity" % - (left, right)) + return left, replace_from_obj_index, use_entity_index + + def _join_place_explicit_left_side(self, left): + """When join conditions express a left side explicitly, determine + where in our existing list of FROM clauses we should join towards, + or if we need to make a new join, and if so is it from one of our + existing entities. + + """ + + # when we are here, it means join() was called with an indicator + # as to an exact left side, which means a path to a + # RelationshipProperty was given, e.g.: + # + # join(RightEntity, LeftEntity.right) + # + # or + # + # join(LeftEntity.right) + # + # as well as string forms: + # + # join(RightEntity, "right") + # + # etc. + # + + replace_from_obj_index = use_entity_index = None + + l_info = inspect(left) + if self._from_obj: + indexes = sql_util.find_left_clause_that_matches_given( + self._from_obj, l_info.selectable) + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't identify which entity in which to assign the " + "left side of this join. Please use a more specific " + "ON clause.") + + # have an index, means the left side is already present in + # an existing FROM in the self._from_obj tuple + if indexes: + replace_from_obj_index = indexes[0] + + # no index, means we need to add a new element to the + # self._from_obj tuple + + # no from element present, so we will have to add to the + # self._from_obj tuple. Determine if this left side matches up + # with existing mapper entities, in which case we want to apply the + # aliasing / adaptation rules present on that entity if any + if replace_from_obj_index is None and \ + self._entities and hasattr(l_info, 'mapper'): + for idx, ent in enumerate(self._entities): + # TODO: should we be checking for multiple mapper entities + # matching? + if isinstance(ent, _MapperEntity) and ent.corresponds_to(left): + use_entity_index = idx + break + + return replace_from_obj_index, use_entity_index + + def _join_check_and_adapt_right_side( + self, left, right, onclause, prop, create_aliases): + """transform the "right" side of the join as well as the onclause + according to polymorphic mapping translations, aliasing on the query + or on the join, special cases where the right and left side have + overlapping tables. + + """ l_info = inspect(left) r_info = inspect(right) @@ -2341,34 +2550,10 @@ class Query(object): "Can't join table/selectable '%s' to itself" % l_info.selectable) - right, onclause = self._prepare_right_side( - r_info, right, onclause, - create_aliases, - prop, overlap) - - # if joining on a MapperProperty path, - # track the path to prevent redundant joins - if not create_aliases and prop: - self._update_joinpoint({ - '_joinpoint_entity': right, - 'prev': ((left, right, prop.key), self._joinpoint) - }) - else: - self._joinpoint = {'_joinpoint_entity': right} - - self._join_to_left(l_info, left, right, onclause, outerjoin, full) - - def _prepare_right_side(self, r_info, right, onclause, create_aliases, - prop, overlap): - info = r_info - right_mapper, right_selectable, right_is_aliased = \ - getattr(info, 'mapper', None), \ - info.selectable, \ - getattr(info, 'is_aliased_class', False) - - if right_mapper: - self._join_entities += (info, ) + getattr(r_info, 'mapper', None), \ + r_info.selectable, \ + getattr(r_info, 'is_aliased_class', False) if right_mapper and prop and \ not right_mapper.common_parent(prop.mapper): @@ -2377,6 +2562,11 @@ class Query(object): "the right side of join condition %s" % (right, onclause) ) + # _join_entities is used as a hint for single-table inheritance + # purposes at the moment + if hasattr(r_info, 'mapper'): + self._join_entities += (r_info, ) + if not right_mapper and prop: right_mapper = prop.mapper @@ -2408,9 +2598,8 @@ class Query(object): ( right_mapper.with_polymorphic and isinstance( right_mapper._with_polymorphic_selectable, - expression.Alias) - or - overlap # test for overlap: + expression.Alias) or overlap + # test for overlap: # orm/inheritance/relationships.py # SelfReferentialM2MTest ) @@ -2447,52 +2636,17 @@ class Query(object): ) ) - return right, onclause - - def _join_to_left(self, l_info, left, right, onclause, outerjoin, full): - info = l_info - left_mapper = getattr(info, 'mapper', None) - left_selectable = info.selectable - - if self._from_obj: - replace_clause_index, clause = sql_util.find_join_source( - self._from_obj, - left_selectable) - if clause is not None: - try: - clause = orm_join(clause, - right, - onclause, isouter=outerjoin, full=full) - except sa_exc.ArgumentError as ae: - raise sa_exc.InvalidRequestError( - "Could not find a FROM clause to join from. " - "Tried joining to %s, but got: %s" % (right, ae)) - - self._from_obj = \ - self._from_obj[:replace_clause_index] + \ - (clause, ) + \ - self._from_obj[replace_clause_index + 1:] - return - - if left_mapper: - for ent in self._entities: - if ent.corresponds_to(left): - clause = ent.selectable - break - else: - clause = left + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if not create_aliases and prop: + self._update_joinpoint({ + '_joinpoint_entity': right, + 'prev': ((left, right, prop.key), self._joinpoint) + }) else: - clause = left_selectable + self._joinpoint = {'_joinpoint_entity': right} - assert clause is not None - try: - clause = orm_join( - clause, right, onclause, isouter=outerjoin, full=full) - except sa_exc.ArgumentError as ae: - raise sa_exc.InvalidRequestError( - "Could not find a FROM clause to join from. " - "Tried joining to %s, but got: %s" % (right, ae)) - self._from_obj = self._from_obj + (clause,) + return right, inspect(right), onclause def _reset_joinpoint(self): self._joinpoint = self._joinpath @@ -4049,8 +4203,8 @@ class _BundleEntity(_QueryEntity): return None def corresponds_to(self, entity): - # TODO: this seems to have no effect for - # _ColumnEntity either + # TODO: we might be able to implement this but for now + # we are working around it return False @property @@ -4226,8 +4380,6 @@ class _ColumnEntity(_QueryEntity): self.froms.add(ext_info.selectable) def corresponds_to(self, entity): - # TODO: just returning False here, - # no tests fail if self.entity_zero is None: return False elif _is_aliased_class(entity): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 9d83952d3..47791f9b9 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1544,14 +1544,24 @@ class JoinedLoader(AbstractRelationshipLoader): if entity not in context.eager_joins and \ not should_nest_selectable and \ context.from_clause: - index, clause = sql_util.find_join_source( + indexes = sql_util.find_left_clause_that_matches_given( context.from_clause, entity.selectable) - if clause is not None: + + if len(indexes) > 1: + # for the eager load case, I can't reproduce this right + # now. For query.join() I can. + raise sa_exc.InvalidRequestError( + "Can't identify which entity in which to joined eager " + "load from. Please use an exact match when specifying " + "the join path.") + + if indexes: + clause = context.from_clause[indexes[0]] # join to an existing FROM clause on the query. # key it to its list index in the eager_joins dict. # Query._compile_context will adapt as needed and # append to the FROM clause of the select(). - entity_key, default_towrap = index, clause + entity_key, default_towrap = indexes[0], clause if entity_key is None: entity_key, default_towrap = entity, entity.selectable diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 64886b326..f64f152c4 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -987,6 +987,19 @@ class Join(FromClause): return and_(*crit) @classmethod + def _can_join(cls, left, right, consider_as_foreign_keys=None): + if isinstance(left, Join): + left_right = left.right + else: + left_right = None + + constraints = cls._joincond_scan_left_right( + a=left, b=right, a_subset=left_right, + consider_as_foreign_keys=consider_as_foreign_keys) + + return bool(constraints) + + @classmethod def _joincond_scan_left_right( cls, a, a_subset, b, consider_as_foreign_keys): constraints = collections.defaultdict(list) @@ -1059,6 +1072,7 @@ class Join(FromClause): "Please specify the 'onclause' of this " "join explicitly." % (a.description, b.description)) + def select(self, whereclause=None, **kwargs): r"""Create a :class:`.Select` from this :class:`.Join`. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f15cc95ed..12cfe09d1 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -49,12 +49,97 @@ def find_join_source(clauses, join_to): """ selectables = list(_from_objects(join_to)) + idx = [] for i, f in enumerate(clauses): for s in selectables: if f.is_derived_from(s): - return i, f + idx.append(i) + return idx + + +def find_left_clause_that_matches_given(clauses, join_from): + """Given a list of FROM clauses and a selectable, + return the indexes from the list of + clauses which is derived from the selectable. + + """ + + selectables = list(_from_objects(join_from)) + liberal_idx = [] + for i, f in enumerate(clauses): + for s in selectables: + # basic check, if f is derived from s. + # this can be joins containing a table, or an aliased table + # or select statement matching to a table. This check + # will match a table to a selectable that is adapted from + # that table. With Query, this suits the case where a join + # is being made to an adapted entity + if f.is_derived_from(s): + liberal_idx.append(i) + break + + # in an extremely small set of use cases, a join is being made where + # there are multiple FROM clauses where our target table is represented + # in more than one, such as embedded or similar. in this case, do + # another pass where we try to get a more exact match where we aren't + # looking at adaption relationships. + if len(liberal_idx) > 1: + conservative_idx = [] + for idx in liberal_idx: + f = clauses[idx] + for s in selectables: + if set(surface_selectables(f)).\ + intersection(surface_selectables(s)): + conservative_idx.append(idx) + break + if conservative_idx: + return conservative_idx + + return liberal_idx + + +def find_left_clause_to_join_from(clauses, join_to, onclause): + """Given a list of FROM clauses, a selectable, + and optional ON clause, return a list of integer indexes from the + clauses list indicating the clauses that can be joined from. + + The presense of an "onclause" indicates that at least one clause can + definitely be joined from; if the list of clauses is of length one + and the onclause is given, returns that index. If the list of clauses + is more than length one, and the onclause is given, attempts to locate + which clauses contain the same columns. + + """ + idx = [] + selectables = set(_from_objects(join_to)) + + # if we are given more than one target clause to join + # from, use the onclause to provide a more specific answer. + # otherwise, don't try to limit, after all, "ON TRUE" is a valid + # on clause + if len(clauses) > 1 and onclause is not None: + resolve_ambiguity = True + cols_in_onclause = _find_columns(onclause) + else: + resolve_ambiguity = False + cols_in_onclause = None + + for i, f in enumerate(clauses): + for s in selectables.difference([f]): + if resolve_ambiguity: + if set(f.c).union(s.c).issuperset(cols_in_onclause): + idx.append(i) + break + elif Join._can_join(f, s) or onclause is not None: + idx.append(i) + break + + # onclause was given and none of them resolved, so assume + # all indexes can match + if not idx and onclause is not None: + return range(len(clauses)) else: - return None, None + return idx def visit_binary_product(fn, expr): |
