diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 44 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 116 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 80 |
5 files changed, 136 insertions, 126 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 494d94bb0..4de438e55 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -13,7 +13,7 @@ mapped attributes. from sqlalchemy import sql, util, log, exc as sa_exc from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ - join_condition + join_condition, _shallow_annotate from sqlalchemy.sql import operators, expression from sqlalchemy.orm import attributes, dependency, mapper, \ object_mapper, strategies, configure_mappers @@ -167,9 +167,6 @@ class ColumnProperty(StrategizedProperty): log.class_logger(ColumnProperty) - - - class RelationshipProperty(StrategizedProperty): """Describes an object property that holds a single item or list of items that correspond to a related database table. @@ -448,7 +445,7 @@ class RelationshipProperty(StrategizedProperty): # should not correlate or otherwise reach out # to anything in the enclosing query. if criterion is not None: - criterion = criterion._annotate({'_halt_adapt': True}) + criterion = criterion._annotate({'no_replacement_traverse': True}) crit = j & criterion @@ -1485,6 +1482,14 @@ class RelationshipProperty(StrategizedProperty): else: aliased = True + # place a barrier on the destination such that + # replacement traversals won't ever dig into it. + # its internal structure remains fixed + # regardless of context. + dest_selectable = _shallow_annotate( + dest_selectable, + {'no_replacement_traverse':True}) + aliased = aliased or (source_selectable is not None) primaryjoin, secondaryjoin, secondary = self.primaryjoin, \ @@ -1508,13 +1513,10 @@ class RelationshipProperty(StrategizedProperty): if secondary is not None: secondary = secondary.alias() primary_aliasizer = ClauseAdapter(secondary) - if dest_selectable is not None: - secondary_aliasizer = \ - ClauseAdapter(dest_selectable, - equivalents=self.mapper._equivalent_columns).\ - chain(primary_aliasizer) - else: - secondary_aliasizer = primary_aliasizer + secondary_aliasizer = \ + ClauseAdapter(dest_selectable, + equivalents=self.mapper._equivalent_columns).\ + chain(primary_aliasizer) if source_selectable is not None: primary_aliasizer = \ ClauseAdapter(secondary).\ @@ -1523,20 +1525,14 @@ class RelationshipProperty(StrategizedProperty): secondaryjoin = \ secondary_aliasizer.traverse(secondaryjoin) else: - if dest_selectable is not None: - primary_aliasizer = ClauseAdapter(dest_selectable, - exclude=self.local_side, - equivalents=self.mapper._equivalent_columns) - if source_selectable is not None: - primary_aliasizer.chain( - ClauseAdapter(source_selectable, - exclude=self.remote_side, - equivalents=self.parent._equivalent_columns)) - elif source_selectable is not None: - primary_aliasizer = \ + primary_aliasizer = ClauseAdapter(dest_selectable, + exclude=self.local_side, + equivalents=self.mapper._equivalent_columns) + if source_selectable is not None: + primary_aliasizer.chain( ClauseAdapter(source_selectable, exclude=self.remote_side, - equivalents=self.parent._equivalent_columns) + equivalents=self.parent._equivalent_columns)) secondary_aliasizer = None primaryjoin = primary_aliasizer.traverse(primaryjoin) target_adapter = secondary_aliasizer or primary_aliasizer diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8d64d69b4..a3b13abb2 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -251,9 +251,6 @@ class Query(object): return clause def replace(elem): - if '_halt_adapt' in elem._annotations: - return elem - for _orm_only, adapter in adapters: # if 'orm only', look for ORM annotations # in the element before adapting. @@ -267,7 +264,7 @@ class Query(object): return visitors.replacement_traverse( clause, - {'column_collections':False}, + {}, replace ) @@ -438,7 +435,9 @@ class Query(object): statement if self._params: stmt = stmt.params(self._params) - return stmt._annotate({'_halt_adapt': True}) + # TODO: there's no tests covering effects of + # the annotation not being there + return stmt._annotate({'no_replacement_traverse': True}) def subquery(self, name=None): """return the full SELECT statement represented by this :class:`.Query`, diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 071bb3c50..fa0586e2d 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1216,7 +1216,7 @@ def _string_or_unprintable(element): except: return "unprintable element %r" % element -def _clone(element): +def _clone(element, **kw): return element._clone() def _expand_cloned(elements): @@ -1522,12 +1522,16 @@ class ClauseElement(Visitable): """ return self is other - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly shallow-copied elements to create a deep copy. + The given clone function should be used, which may be applying + additional transformations to the element (i.e. replacement + traversal, cloned traversal, annotations). + """ pass @@ -2755,8 +2759,8 @@ class _TextClause(Executable, ClauseElement): else: return self - def _copy_internals(self, clone=_clone): - self.bindparams = dict((b.key, clone(b)) + def _copy_internals(self, clone=_clone, **kw): + self.bindparams = dict((b.key, clone(b, **kw)) for b in self.bindparams.values()) def get_children(self, **kwargs): @@ -2846,8 +2850,8 @@ class ClauseList(ClauseElement): else: self.clauses.append(_literal_as_text(clause)) - def _copy_internals(self, clone=_clone): - self.clauses = [clone(clause) for clause in self.clauses] + def _copy_internals(self, clone=_clone, **kw): + self.clauses = [clone(clause, **kw) for clause in self.clauses] def get_children(self, **kwargs): return self.clauses @@ -2947,12 +2951,13 @@ class _Case(ColumnElement): else: self.else_ = None - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): if self.value is not None: - self.value = clone(self.value) - self.whens = [(clone(x), clone(y)) for x, y in self.whens] + self.value = clone(self.value, **kw) + self.whens = [(clone(x, **kw), clone(y, **kw)) + for x, y in self.whens] if self.else_ is not None: - self.else_ = clone(self.else_) + self.else_ = clone(self.else_, **kw) def get_children(self, **kwargs): if self.value is not None: @@ -3028,8 +3033,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): def get_children(self, **kwargs): return self.clause_expr, - def _copy_internals(self, clone=_clone): - self.clause_expr = clone(self.clause_expr) + def _copy_internals(self, clone=_clone, **kw): + self.clause_expr = clone(self.clause_expr, **kw) self._reset_exported() util.reset_memoized(self, 'clauses') @@ -3120,9 +3125,9 @@ class _Cast(ColumnElement): self.clause = _literal_as_binds(clause, None) self.typeclause = _TypeClause(self.type) - def _copy_internals(self, clone=_clone): - self.clause = clone(self.clause) - self.typeclause = clone(self.typeclause) + def _copy_internals(self, clone=_clone, **kw): + self.clause = clone(self.clause, **kw) + self.typeclause = clone(self.typeclause, **kw) def get_children(self, **kwargs): return self.clause, self.typeclause @@ -3141,8 +3146,8 @@ class _Extract(ColumnElement): self.field = field self.expr = _literal_as_binds(expr, None) - def _copy_internals(self, clone=_clone): - self.expr = clone(self.expr) + def _copy_internals(self, clone=_clone, **kw): + self.expr = clone(self.expr, **kw) def get_children(self, **kwargs): return self.expr, @@ -3170,8 +3175,8 @@ class _UnaryExpression(ColumnElement): def _from_objects(self): return self.element._from_objects - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) def get_children(self, **kwargs): return self.element, @@ -3233,9 +3238,9 @@ class _BinaryExpression(ColumnElement): def _from_objects(self): return self.left._from_objects + self.right._from_objects - def _copy_internals(self, clone=_clone): - self.left = clone(self.left) - self.right = clone(self.right) + def _copy_internals(self, clone=_clone, **kw): + self.left = clone(self.left, **kw) + self.right = clone(self.right, **kw) def get_children(self, **kwargs): return self.left, self.right @@ -3373,11 +3378,11 @@ class Join(FromClause): self.foreign_keys.update(itertools.chain( *[col.foreign_keys for col in columns])) - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): self._reset_exported() - self.left = clone(self.left) - self.right = clone(self.right) - self.onclause = clone(self.onclause) + self.left = clone(self.left, **kw) + self.right = clone(self.right, **kw) + self.onclause = clone(self.onclause, **kw) self.__folded_equivalents = None def get_children(self, **kwargs): @@ -3525,21 +3530,24 @@ class Alias(FromClause): for col in self.element.columns: col._make_proxy(self) - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): + # don't apply anything to an aliased Table + # for now. May want to drive this from + # the given **kw. + if isinstance(self.element, TableClause): + return self._reset_exported() - self.element = _clone(self.element) + self.element = clone(self.element, **kw) baseselectable = self.element while isinstance(baseselectable, Alias): baseselectable = baseselectable.element self.original = baseselectable - def get_children(self, column_collections=True, - aliased_selectables=True, **kwargs): + def get_children(self, column_collections=True, **kw): if column_collections: for c in self.c: yield c - if aliased_selectables: - yield self.element + yield self.element @property def _from_objects(self): @@ -3563,8 +3571,8 @@ class _Grouping(ColumnElement): def _label(self): return getattr(self.element, '_label', None) or self.anon_label - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) def get_children(self, **kwargs): return self.element, @@ -3615,8 +3623,8 @@ class _FromGrouping(FromClause): def get_children(self, **kwargs): return self.element, - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) @property def _from_objects(self): @@ -3662,12 +3670,12 @@ class _Over(ColumnElement): (self.func, self.partition_by, self.order_by) if c is not None] - def _copy_internals(self, clone=_clone): - self.func = clone(self.func) + def _copy_internals(self, clone=_clone, **kw): + self.func = clone(self.func, **kw) if self.partition_by is not None: - self.partition_by = clone(self.partition_by) + self.partition_by = clone(self.partition_by, **kw) if self.order_by is not None: - self.order_by = clone(self.order_by) + self.order_by = clone(self.order_by, **kw) @property def _from_objects(self): @@ -3732,8 +3740,8 @@ class _Label(ColumnElement): def get_children(self, **kwargs): return self.element, - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) @property def _from_objects(self): @@ -4244,14 +4252,14 @@ class CompoundSelect(_SelectBase): proxy.proxies = [c._annotate({'weight': i + 1}) for (i, c) in enumerate(cols)] - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): self._reset_exported() - self.selects = [clone(s) for s in self.selects] + self.selects = [clone(s, **kw) for s in self.selects] if hasattr(self, '_col_map'): del self._col_map for attr in ('_order_by_clause', '_group_by_clause'): if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr))) + setattr(self, attr, clone(getattr(self, attr), **kw)) def get_children(self, column_collections=True, **kwargs): return (column_collections and list(self.c) or []) \ @@ -4477,17 +4485,17 @@ class Select(_SelectBase): return True return False - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): self._reset_exported() - from_cloned = dict((f, clone(f)) + from_cloned = dict((f, clone(f, **kw)) for f in self._froms.union(self._correlate)) self._froms = util.OrderedSet(from_cloned[f] for f in self._froms) self._correlate = set(from_cloned[f] for f in self._correlate) - self._raw_columns = [clone(c) for c in self._raw_columns] + self._raw_columns = [clone(c, **kw) for c in self._raw_columns] for attr in '_whereclause', '_having', '_order_by_clause', \ '_group_by_clause': if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr))) + setattr(self, attr, clone(getattr(self, attr), **kw)) def get_children(self, column_collections=True, **kwargs): """return child elements as per the ClauseElement specification.""" @@ -4910,7 +4918,7 @@ class Insert(ValuesBase): else: return () - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): # TODO: coverage self.parameters = self.parameters.copy() @@ -4959,9 +4967,9 @@ class Update(ValuesBase): else: return () - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): # TODO: coverage - self._whereclause = clone(self._whereclause) + self._whereclause = clone(self._whereclause, **kw) self.parameters = self.parameters.copy() @_generative @@ -5020,9 +5028,9 @@ class Delete(UpdateBase): else: self._whereclause = _literal_as_text(whereclause) - def _copy_internals(self, clone=_clone): + def _copy_internals(self, clone=_clone, **kw): # TODO: coverage - self._whereclause = clone(self._whereclause) + self._whereclause = clone(self._whereclause, **kw) class _IdentifiedClause(Executable, ClauseElement): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 77c3e45ec..ed0afef24 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -417,6 +417,17 @@ def _deep_deannotate(element): element = clone(element) return element +def _shallow_annotate(element, annotations): + """Annotate the given ClauseElement and copy its internals so that + internal objects refer to the new annotated object. + + Basically used to apply a "dont traverse" annotation to a + selectable, without digging throughout the whole + structure wasting time. + """ + element = element._annotate(annotations) + element._copy_internals() + return element def splice_joins(left, right, stop_on=None): if left is None: @@ -639,7 +650,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): """ def __init__(self, selectable, equivalents=None, include=None, exclude=None): - self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]} + self.__traverse_options__ = {'stop_on':[selectable]} self.selectable = selectable self.include = include self.exclude = exclude diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 0c6be97d7..b94f07f58 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -19,7 +19,7 @@ use a non-visitor traversal system. For many examples of how the visit system is used, see the sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules. For an introduction to clause adaption, see -http://techspot.zzzeek.org/?p=19 . +http://techspot.zzzeek.org/2008/01/23/expression-transformations/ """ @@ -212,55 +212,51 @@ def traverse_depthfirst(obj, opts, visitors): return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) def cloned_traverse(obj, opts, visitors): - """clone the given expression structure, allowing modifications by visitors.""" + """clone the given expression structure, allowing + modifications by visitors.""" cloned = util.column_dict() + stop_on = util.column_set(opts.get('stop_on', [])) - def clone(element): - if element not in cloned: - cloned[element] = element._clone() - return cloned[element] - - obj = clone(obj) - stack = [obj] - - while stack: - t = stack.pop() - if t in cloned: - continue - t._copy_internals(clone=clone) - - meth = visitors.get(t.__visit_name__, None) - if meth: - meth(t) - - for c in t.get_children(**opts): - stack.append(c) + def clone(elem): + if elem in stop_on: + return elem + else: + if elem not in cloned: + cloned[elem] = newelem = elem._clone() + newelem._copy_internals(clone=clone) + meth = visitors.get(newelem.__visit_name__, None) + if meth: + meth(newelem) + return cloned[elem] + + if obj is not None: + obj = clone(obj) return obj + def replacement_traverse(obj, opts, replace): - """clone the given expression structure, allowing element replacement by a given replacement function.""" + """clone the given expression structure, allowing element + replacement by a given replacement function.""" cloned = util.column_dict() stop_on = util.column_set(opts.get('stop_on', [])) - def clone(element): - newelem = replace(element) - if newelem is not None: - stop_on.add(newelem) - return newelem - - if element not in cloned: - cloned[element] = element._clone() - return cloned[element] - - obj = clone(obj) - stack = [obj] - while stack: - t = stack.pop() - if t in stop_on: - continue - t._copy_internals(clone=clone) - for c in t.get_children(**opts): - stack.append(c) + def clone(elem, **kw): + if elem in stop_on or \ + 'no_replacement_traverse' in elem._annotations: + return elem + else: + newelem = replace(elem) + if newelem is not None: + stop_on.add(newelem) + return newelem + else: + if elem not in cloned: + cloned[elem] = newelem = elem._clone() + newelem._copy_internals(clone=clone, **kw) + return cloned[elem] + + if obj is not None: + obj = clone(obj, **opts) return obj |
