diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 56 |
2 files changed, 55 insertions, 13 deletions
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index c0609dba3..23114cdab 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -372,8 +372,18 @@ class LazyLoader(AbstractRelationLoader): # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() #from sqlalchemy.orm import query - self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere) + self.use_get = not self.uselist and \ + self.mapper._get_clause[0].compare( + self.__lazywhere, + use_proxies=True, + equivalents=self.mapper._equivalent_columns + ) if self.use_get: + for col in self._equated_columns.keys(): + if col in self.mapper._equivalent_columns: + for c in self.mapper._equivalent_columns[col]: + self._equated_columns[c] = self._equated_columns[col] + self.logger.info("%s will use query.get() to optimize instance loads" % self) def init_class_attribute(self, mapper): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 960fc0310..8c6877dbd 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1101,11 +1101,15 @@ class ClauseElement(Visitable): bind._convert_to_unique() return cloned_traverse(self, {}, {'bindparam':visit_bindparam}) - def compare(self, other): + def compare(self, other, **kw): """Compare this ClauseElement to the given ClauseElement. Subclasses should override the default behavior, which is a straight identity comparison. + + **kw are arguments consumed by subclass compare() methods and + may be used to modify the criteria for comparison. + (see :class:`ColumnElement`) """ return self is other @@ -1697,6 +1701,34 @@ class ColumnElement(ClauseElement, _CompareMixin): selectable.columns[name] = co return co + def compare(self, other, use_proxies=False, equivalents=None, **kw): + """Compare this ColumnElement to another. + + Special arguments understood: + + :param use_proxies: when True, consider two columns that + share a common base column as equivalent (i.e. shares_lineage()) + + :param equivalents: a dictionary of columns as keys mapped to sets + of columns. If the given "other" column is present in this dictionary, + if any of the columns in the correponding set() pass the comparison + test, the result is True. This is used to expand the comparison to + other columns that may be known to be equivalent to this one via + foreign key or other criterion. + + """ + to_compare = (other, ) + if equivalents and other in equivalents: + to_compare = equivalents[other].union(to_compare) + + for oth in to_compare: + if use_proxies and self.shares_lineage(oth): + return True + elif oth is self: + return True + else: + return False + @util.memoized_property def anon_label(self): """provides a constant 'anonymous label' for this ColumnElement. @@ -2109,7 +2141,7 @@ class _BindParamClause(ColumnElement): else: return obj.type - def compare(self, other): + def compare(self, other, **kw): """Compare this ``_BindParamClause`` to the given clause. Since ``compare()`` is meant to compare statement syntax, this @@ -2274,16 +2306,16 @@ class ClauseList(ClauseElement): else: return self - def compare(self, other): + def compare(self, other, **kw): """Compare this ``ClauseList`` to the given ``ClauseList``, including a comparison of all the clause items. """ if not isinstance(other, ClauseList) and len(self.clauses) == 1: - return self.clauses[0].compare(other) + return self.clauses[0].compare(other, **kw) elif isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses): for i in range(0, len(self.clauses)): - if not self.clauses[i].compare(other.clauses[i]): + if not self.clauses[i].compare(other.clauses[i], **kw): return False else: return self.operator == other.operator @@ -2473,14 +2505,14 @@ class _UnaryExpression(ColumnElement): def get_children(self, **kwargs): return self.element, - def compare(self, other): + def compare(self, other, **kw): """Compare this ``_UnaryExpression`` against the given ``ClauseElement``.""" return ( isinstance(other, _UnaryExpression) and self.operator == other.operator and self.modifier == other.modifier and - self.element.compare(other.element) + self.element.compare(other.element, **kw) ) def _negate(self): @@ -2528,19 +2560,19 @@ class _BinaryExpression(ColumnElement): def get_children(self, **kwargs): return self.left, self.right - def compare(self, other): + def compare(self, other, **kw): """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" return ( isinstance(other, _BinaryExpression) and self.operator == other.operator and ( - self.left.compare(other.left) and - self.right.compare(other.right) or + self.left.compare(other.left, **kw) and + self.right.compare(other.right, **kw) or ( operators.is_commutative(self.operator) and - self.left.compare(other.right) and - self.right.compare(other.left) + self.left.compare(other.right, **kw) and + self.right.compare(other.left, **kw) ) ) ) |
