diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 164 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 12 |
4 files changed, 167 insertions, 49 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f9a3863da..d8ad7c3fa 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1584,18 +1584,35 @@ class ClauseElement(Visitable): return id(self) def _annotate(self, values): - """return a copy of this ClauseElement with the given annotations - dictionary. + """return a copy of this ClauseElement with annotations + updated by the given dictionary. """ return sqlutil.Annotated(self, values) - def _deannotate(self): - """return a copy of this ClauseElement with an empty annotations - dictionary. + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. """ - return self._clone() + return sqlutil.Annotated(self, values) + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone: + # clone is used when we are also copying + # the expression for a deep deannotation + return self._clone() + else: + # if no clone, since we have no annotations we return + # self + return self def unique_params(self, *optionaldict, **kwargs): """Return a copy with :func:`bindparam()` elements replaced. @@ -2195,7 +2212,7 @@ class ColumnElement(ClauseElement, _CompareMixin): for oth in to_compare: if use_proxies and self.shares_lineage(oth): return True - elif oth is self: + elif hash(oth) == hash(self): return True else: return False @@ -3403,6 +3420,10 @@ class _BinaryExpression(ColumnElement): raise TypeError("Boolean value of this clause is not defined") @property + def is_comparison(self): + return operators.is_comparison(self.operator) + + @property def _from_objects(self): return self.left._from_objects + self.right._from_objects diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 89f0aaee1..b86b50db4 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -521,6 +521,11 @@ def nullslast_op(a): _commutative = set([eq, ne, add, mul]) +_comparison = set([eq, ne, lt, gt, ge, le]) + +def is_comparison(op): + return op in _comparison + def is_commutative(op): return op in _commutative diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 8d2b5ecfd..cb8359048 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -62,6 +62,65 @@ def find_join_source(clauses, join_to): else: return None, None + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a <eq> q + a <eq> e + a <eq> f + b <eq> q + b <eq> e + b <eq> f + j <eq> r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs. + + """ + stack = [] + def visit(element): + if isinstance(element, (expression._ScalarSelect)): + # we dont want to dig into correlated subqueries, + # those are just column elements by themselves + yield element + elif element.__visit_name__ == 'binary' and \ + operators.is_comparison(element.operator): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, expression.ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + list(visit(expr)) + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False, include_crud=False): @@ -225,7 +284,10 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) -def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): + +def join_condition(a, b, ignore_nonexistent_tables=False, + a_subset=None, + consider_as_foreign_keys=None): """create a join condition between two tables or selectables. e.g.:: @@ -261,6 +323,9 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): for fk in sorted( b.foreign_keys, key=lambda fk:fk.parent._creation_order): + if consider_as_foreign_keys is not None and \ + fk.parent not in consider_as_foreign_keys: + continue try: col = fk.get_referent(left) except exc.NoReferenceError, nrte: @@ -276,6 +341,9 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): for fk in sorted( left.foreign_keys, key=lambda fk:fk.parent._creation_order): + if consider_as_foreign_keys is not None and \ + fk.parent not in consider_as_foreign_keys: + continue try: col = fk.get_referent(b) except exc.NoReferenceError, nrte: @@ -298,11 +366,11 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): "subquery using alias()?" else: hint = "" - raise exc.ArgumentError( + raise exc.NoForeignKeysError( "Can't find any foreign key relationships " "between '%s' and '%s'.%s" % (a.description, b.description, hint)) elif len(constraints) > 1: - raise exc.ArgumentError( + raise exc.AmbiguousForeignKeysError( "Can't determine join between '%s' and '%s'; " "tables have more than one foreign key " "constraint relationship between them. " @@ -356,13 +424,22 @@ class Annotated(object): def _annotate(self, values): _values = self._annotations.copy() _values.update(values) + return self._with_annotations(_values) + + def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() - clone._annotations = _values + clone._annotations = values return clone - def _deannotate(self): - return self.__element + def _deannotate(self, values=None, clone=True): + if values is None: + return self.__element + else: + _values = self._annotations.copy() + for v in values: + _values.pop(v, None) + return self._with_annotations(_values) def _compiler_dispatch(self, visitor, **kw): return self.__element.__class__._compiler_dispatch(self, visitor, **kw) @@ -410,14 +487,8 @@ def _deep_annotate(element, annotations, exclude=None): Elements within the exclude collection will be cloned but not annotated. """ - cloned = util.column_dict() - def clone(elem): - # check if element is present in the exclude list. - # take into account proxying relationships. - if elem in cloned: - return cloned[elem] - elif exclude and \ + if exclude and \ hasattr(elem, 'proxy_set') and \ elem.proxy_set.intersection(exclude): newelem = elem._clone() @@ -426,24 +497,32 @@ def _deep_annotate(element, annotations, exclude=None): else: newelem = elem newelem._copy_internals(clone=clone) - cloned[elem] = newelem return newelem if element is not None: element = clone(element) return element -def _deep_deannotate(element): - """Deep copy the given element, removing all annotations.""" +def _deep_deannotate(element, values=None): + """Deep copy the given element, removing annotations.""" cloned = util.column_dict() def clone(elem): - if elem not in cloned: - newelem = elem._deannotate() + # if a values dict is given, + # the elem must be cloned each time it appears, + # as there may be different annotations in source + # elements that are remaining. if totally + # removing all annotations, can assume the same + # slate... + if values or elem not in cloned: + newelem = elem._deannotate(values=values, clone=True) newelem._copy_internals(clone=clone) - cloned[elem] = newelem - return cloned[elem] + if not values: + cloned[elem] = newelem + return newelem + else: + return cloned[elem] if element is not None: element = clone(element) @@ -547,6 +626,10 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, "'consider_as_foreign_keys' or " "'consider_as_referenced_keys'") + def col_is(a, b): + #return a is b + return a.compare(b) + def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return @@ -556,20 +639,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, if consider_as_foreign_keys: if binary.left in consider_as_foreign_keys and \ - (binary.right is binary.left or + (col_is(binary.right, binary.left) or binary.right not in consider_as_foreign_keys): pairs.append((binary.right, binary.left)) elif binary.right in consider_as_foreign_keys and \ - (binary.left is binary.right or + (col_is(binary.left, binary.right) or binary.left not in consider_as_foreign_keys): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: if binary.left in consider_as_referenced_keys and \ - (binary.right is binary.left or + (col_is(binary.right, binary.left) or binary.right not in consider_as_referenced_keys): pairs.append((binary.left, binary.right)) elif binary.right in consider_as_referenced_keys and \ - (binary.left is binary.right or + (col_is(binary.left, binary.right) or binary.left not in consider_as_referenced_keys): pairs.append((binary.right, binary.left)) else: @@ -681,11 +764,22 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): s.c.col1 == table2.c.col1 """ - def __init__(self, selectable, equivalents=None, include=None, exclude=None, adapt_on_names=False): + def __init__(self, selectable, equivalents=None, + include=None, exclude=None, + include_fn=None, exclude_fn=None, + adapt_on_names=False): self.__traverse_options__ = {'stop_on':[selectable]} self.selectable = selectable - self.include = include - self.exclude = exclude + if include: + assert not include_fn + self.include_fn = lambda e: e in include + else: + self.include_fn = include_fn + if exclude: + assert not exclude_fn + self.exclude_fn = lambda e: e in exclude + else: + self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names @@ -705,19 +799,17 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return newcol def replace(self, col): - if isinstance(col, expression.FromClause): - if self.selectable.is_derived_from(col): + if isinstance(col, expression.FromClause) and \ + self.selectable.is_derived_from(col): return self.selectable - - if not isinstance(col, expression.ColumnElement): + elif not isinstance(col, expression.ColumnElement): return None - - if self.include and col not in self.include: + elif self.include_fn and not self.include_fn(col): return None - elif self.exclude and col in self.exclude: + elif self.exclude_fn and self.exclude_fn(col): return None - - return self._corresponding_column(col, True) + else: + return self._corresponding_column(col, True) class ColumnAdapter(ClauseAdapter): """Extends ClauseAdapter with extra utility functions. diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 5354fbcbb..8a06982fc 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -242,13 +242,13 @@ def cloned_traverse(obj, opts, visitors): if elem in stop_on: return elem else: - if elem not in cloned: - cloned[elem] = newelem = elem._clone() + if id(elem) not in cloned: + cloned[id(elem)] = newelem = elem._clone() newelem._copy_internals(clone=clone) meth = visitors.get(newelem.__visit_name__, None) if meth: meth(newelem) - return cloned[elem] + return cloned[id(elem)] if obj is not None: obj = clone(obj) @@ -260,16 +260,16 @@ def replacement_traverse(obj, opts, replace): replacement by a given replacement function.""" cloned = util.column_dict() - stop_on = util.column_set(opts.get('stop_on', [])) + stop_on = util.column_set([id(x) for x in opts.get('stop_on', [])]) def clone(elem, **kw): - if elem in stop_on or \ + if id(elem) in stop_on or \ 'no_replacement_traverse' in elem._annotations: return elem else: newelem = replace(elem) if newelem is not None: - stop_on.add(newelem) + stop_on.add(id(newelem)) return newelem else: if elem not in cloned: |
