summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/properties.py44
-rw-r--r--lib/sqlalchemy/orm/query.py9
-rw-r--r--lib/sqlalchemy/sql/expression.py116
-rw-r--r--lib/sqlalchemy/sql/util.py13
-rw-r--r--lib/sqlalchemy/sql/visitors.py80
5 files changed, 136 insertions, 126 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 494d94bb0..4de438e55 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -13,7 +13,7 @@ mapped attributes.
from sqlalchemy import sql, util, log, exc as sa_exc
from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
- join_condition
+ join_condition, _shallow_annotate
from sqlalchemy.sql import operators, expression
from sqlalchemy.orm import attributes, dependency, mapper, \
object_mapper, strategies, configure_mappers
@@ -167,9 +167,6 @@ class ColumnProperty(StrategizedProperty):
log.class_logger(ColumnProperty)
-
-
-
class RelationshipProperty(StrategizedProperty):
"""Describes an object property that holds a single item or list
of items that correspond to a related database table.
@@ -448,7 +445,7 @@ class RelationshipProperty(StrategizedProperty):
# should not correlate or otherwise reach out
# to anything in the enclosing query.
if criterion is not None:
- criterion = criterion._annotate({'_halt_adapt': True})
+ criterion = criterion._annotate({'no_replacement_traverse': True})
crit = j & criterion
@@ -1485,6 +1482,14 @@ class RelationshipProperty(StrategizedProperty):
else:
aliased = True
+ # place a barrier on the destination such that
+ # replacement traversals won't ever dig into it.
+ # its internal structure remains fixed
+ # regardless of context.
+ dest_selectable = _shallow_annotate(
+ dest_selectable,
+ {'no_replacement_traverse':True})
+
aliased = aliased or (source_selectable is not None)
primaryjoin, secondaryjoin, secondary = self.primaryjoin, \
@@ -1508,13 +1513,10 @@ class RelationshipProperty(StrategizedProperty):
if secondary is not None:
secondary = secondary.alias()
primary_aliasizer = ClauseAdapter(secondary)
- if dest_selectable is not None:
- secondary_aliasizer = \
- ClauseAdapter(dest_selectable,
- equivalents=self.mapper._equivalent_columns).\
- chain(primary_aliasizer)
- else:
- secondary_aliasizer = primary_aliasizer
+ secondary_aliasizer = \
+ ClauseAdapter(dest_selectable,
+ equivalents=self.mapper._equivalent_columns).\
+ chain(primary_aliasizer)
if source_selectable is not None:
primary_aliasizer = \
ClauseAdapter(secondary).\
@@ -1523,20 +1525,14 @@ class RelationshipProperty(StrategizedProperty):
secondaryjoin = \
secondary_aliasizer.traverse(secondaryjoin)
else:
- if dest_selectable is not None:
- primary_aliasizer = ClauseAdapter(dest_selectable,
- exclude=self.local_side,
- equivalents=self.mapper._equivalent_columns)
- if source_selectable is not None:
- primary_aliasizer.chain(
- ClauseAdapter(source_selectable,
- exclude=self.remote_side,
- equivalents=self.parent._equivalent_columns))
- elif source_selectable is not None:
- primary_aliasizer = \
+ primary_aliasizer = ClauseAdapter(dest_selectable,
+ exclude=self.local_side,
+ equivalents=self.mapper._equivalent_columns)
+ if source_selectable is not None:
+ primary_aliasizer.chain(
ClauseAdapter(source_selectable,
exclude=self.remote_side,
- equivalents=self.parent._equivalent_columns)
+ equivalents=self.parent._equivalent_columns))
secondary_aliasizer = None
primaryjoin = primary_aliasizer.traverse(primaryjoin)
target_adapter = secondary_aliasizer or primary_aliasizer
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 8d64d69b4..a3b13abb2 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -251,9 +251,6 @@ class Query(object):
return clause
def replace(elem):
- if '_halt_adapt' in elem._annotations:
- return elem
-
for _orm_only, adapter in adapters:
# if 'orm only', look for ORM annotations
# in the element before adapting.
@@ -267,7 +264,7 @@ class Query(object):
return visitors.replacement_traverse(
clause,
- {'column_collections':False},
+ {},
replace
)
@@ -438,7 +435,9 @@ class Query(object):
statement
if self._params:
stmt = stmt.params(self._params)
- return stmt._annotate({'_halt_adapt': True})
+ # TODO: there's no tests covering effects of
+ # the annotation not being there
+ return stmt._annotate({'no_replacement_traverse': True})
def subquery(self, name=None):
"""return the full SELECT statement represented by this :class:`.Query`,
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 071bb3c50..fa0586e2d 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1216,7 +1216,7 @@ def _string_or_unprintable(element):
except:
return "unprintable element %r" % element
-def _clone(element):
+def _clone(element, **kw):
return element._clone()
def _expand_cloned(elements):
@@ -1522,12 +1522,16 @@ class ClauseElement(Visitable):
"""
return self is other
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
shallow-copied elements to create a deep copy.
+ The given clone function should be used, which may be applying
+ additional transformations to the element (i.e. replacement
+ traversal, cloned traversal, annotations).
+
"""
pass
@@ -2755,8 +2759,8 @@ class _TextClause(Executable, ClauseElement):
else:
return self
- def _copy_internals(self, clone=_clone):
- self.bindparams = dict((b.key, clone(b))
+ def _copy_internals(self, clone=_clone, **kw):
+ self.bindparams = dict((b.key, clone(b, **kw))
for b in self.bindparams.values())
def get_children(self, **kwargs):
@@ -2846,8 +2850,8 @@ class ClauseList(ClauseElement):
else:
self.clauses.append(_literal_as_text(clause))
- def _copy_internals(self, clone=_clone):
- self.clauses = [clone(clause) for clause in self.clauses]
+ def _copy_internals(self, clone=_clone, **kw):
+ self.clauses = [clone(clause, **kw) for clause in self.clauses]
def get_children(self, **kwargs):
return self.clauses
@@ -2947,12 +2951,13 @@ class _Case(ColumnElement):
else:
self.else_ = None
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
if self.value is not None:
- self.value = clone(self.value)
- self.whens = [(clone(x), clone(y)) for x, y in self.whens]
+ self.value = clone(self.value, **kw)
+ self.whens = [(clone(x, **kw), clone(y, **kw))
+ for x, y in self.whens]
if self.else_ is not None:
- self.else_ = clone(self.else_)
+ self.else_ = clone(self.else_, **kw)
def get_children(self, **kwargs):
if self.value is not None:
@@ -3028,8 +3033,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
def get_children(self, **kwargs):
return self.clause_expr,
- def _copy_internals(self, clone=_clone):
- self.clause_expr = clone(self.clause_expr)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.clause_expr = clone(self.clause_expr, **kw)
self._reset_exported()
util.reset_memoized(self, 'clauses')
@@ -3120,9 +3125,9 @@ class _Cast(ColumnElement):
self.clause = _literal_as_binds(clause, None)
self.typeclause = _TypeClause(self.type)
- def _copy_internals(self, clone=_clone):
- self.clause = clone(self.clause)
- self.typeclause = clone(self.typeclause)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.clause = clone(self.clause, **kw)
+ self.typeclause = clone(self.typeclause, **kw)
def get_children(self, **kwargs):
return self.clause, self.typeclause
@@ -3141,8 +3146,8 @@ class _Extract(ColumnElement):
self.field = field
self.expr = _literal_as_binds(expr, None)
- def _copy_internals(self, clone=_clone):
- self.expr = clone(self.expr)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.expr = clone(self.expr, **kw)
def get_children(self, **kwargs):
return self.expr,
@@ -3170,8 +3175,8 @@ class _UnaryExpression(ColumnElement):
def _from_objects(self):
return self.element._from_objects
- def _copy_internals(self, clone=_clone):
- self.element = clone(self.element)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
return self.element,
@@ -3233,9 +3238,9 @@ class _BinaryExpression(ColumnElement):
def _from_objects(self):
return self.left._from_objects + self.right._from_objects
- def _copy_internals(self, clone=_clone):
- self.left = clone(self.left)
- self.right = clone(self.right)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.left = clone(self.left, **kw)
+ self.right = clone(self.right, **kw)
def get_children(self, **kwargs):
return self.left, self.right
@@ -3373,11 +3378,11 @@ class Join(FromClause):
self.foreign_keys.update(itertools.chain(
*[col.foreign_keys for col in columns]))
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
self._reset_exported()
- self.left = clone(self.left)
- self.right = clone(self.right)
- self.onclause = clone(self.onclause)
+ self.left = clone(self.left, **kw)
+ self.right = clone(self.right, **kw)
+ self.onclause = clone(self.onclause, **kw)
self.__folded_equivalents = None
def get_children(self, **kwargs):
@@ -3525,21 +3530,24 @@ class Alias(FromClause):
for col in self.element.columns:
col._make_proxy(self)
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
+ # don't apply anything to an aliased Table
+ # for now. May want to drive this from
+ # the given **kw.
+ if isinstance(self.element, TableClause):
+ return
self._reset_exported()
- self.element = _clone(self.element)
+ self.element = clone(self.element, **kw)
baseselectable = self.element
while isinstance(baseselectable, Alias):
baseselectable = baseselectable.element
self.original = baseselectable
- def get_children(self, column_collections=True,
- aliased_selectables=True, **kwargs):
+ def get_children(self, column_collections=True, **kw):
if column_collections:
for c in self.c:
yield c
- if aliased_selectables:
- yield self.element
+ yield self.element
@property
def _from_objects(self):
@@ -3563,8 +3571,8 @@ class _Grouping(ColumnElement):
def _label(self):
return getattr(self.element, '_label', None) or self.anon_label
- def _copy_internals(self, clone=_clone):
- self.element = clone(self.element)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
return self.element,
@@ -3615,8 +3623,8 @@ class _FromGrouping(FromClause):
def get_children(self, **kwargs):
return self.element,
- def _copy_internals(self, clone=_clone):
- self.element = clone(self.element)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.element = clone(self.element, **kw)
@property
def _from_objects(self):
@@ -3662,12 +3670,12 @@ class _Over(ColumnElement):
(self.func, self.partition_by, self.order_by)
if c is not None]
- def _copy_internals(self, clone=_clone):
- self.func = clone(self.func)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.func = clone(self.func, **kw)
if self.partition_by is not None:
- self.partition_by = clone(self.partition_by)
+ self.partition_by = clone(self.partition_by, **kw)
if self.order_by is not None:
- self.order_by = clone(self.order_by)
+ self.order_by = clone(self.order_by, **kw)
@property
def _from_objects(self):
@@ -3732,8 +3740,8 @@ class _Label(ColumnElement):
def get_children(self, **kwargs):
return self.element,
- def _copy_internals(self, clone=_clone):
- self.element = clone(self.element)
+ def _copy_internals(self, clone=_clone, **kw):
+ self.element = clone(self.element, **kw)
@property
def _from_objects(self):
@@ -4244,14 +4252,14 @@ class CompoundSelect(_SelectBase):
proxy.proxies = [c._annotate({'weight': i + 1}) for (i,
c) in enumerate(cols)]
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
self._reset_exported()
- self.selects = [clone(s) for s in self.selects]
+ self.selects = [clone(s, **kw) for s in self.selects]
if hasattr(self, '_col_map'):
del self._col_map
for attr in ('_order_by_clause', '_group_by_clause'):
if getattr(self, attr) is not None:
- setattr(self, attr, clone(getattr(self, attr)))
+ setattr(self, attr, clone(getattr(self, attr), **kw))
def get_children(self, column_collections=True, **kwargs):
return (column_collections and list(self.c) or []) \
@@ -4477,17 +4485,17 @@ class Select(_SelectBase):
return True
return False
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
self._reset_exported()
- from_cloned = dict((f, clone(f))
+ from_cloned = dict((f, clone(f, **kw))
for f in self._froms.union(self._correlate))
self._froms = util.OrderedSet(from_cloned[f] for f in self._froms)
self._correlate = set(from_cloned[f] for f in self._correlate)
- self._raw_columns = [clone(c) for c in self._raw_columns]
+ self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
for attr in '_whereclause', '_having', '_order_by_clause', \
'_group_by_clause':
if getattr(self, attr) is not None:
- setattr(self, attr, clone(getattr(self, attr)))
+ setattr(self, attr, clone(getattr(self, attr), **kw))
def get_children(self, column_collections=True, **kwargs):
"""return child elements as per the ClauseElement specification."""
@@ -4910,7 +4918,7 @@ class Insert(ValuesBase):
else:
return ()
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
self.parameters = self.parameters.copy()
@@ -4959,9 +4967,9 @@ class Update(ValuesBase):
else:
return ()
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
- self._whereclause = clone(self._whereclause)
+ self._whereclause = clone(self._whereclause, **kw)
self.parameters = self.parameters.copy()
@_generative
@@ -5020,9 +5028,9 @@ class Delete(UpdateBase):
else:
self._whereclause = _literal_as_text(whereclause)
- def _copy_internals(self, clone=_clone):
+ def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
- self._whereclause = clone(self._whereclause)
+ self._whereclause = clone(self._whereclause, **kw)
class _IdentifiedClause(Executable, ClauseElement):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 77c3e45ec..ed0afef24 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -417,6 +417,17 @@ def _deep_deannotate(element):
element = clone(element)
return element
+def _shallow_annotate(element, annotations):
+ """Annotate the given ClauseElement and copy its internals so that
+ internal objects refer to the new annotated object.
+
+ Basically used to apply a "dont traverse" annotation to a
+ selectable, without digging throughout the whole
+ structure wasting time.
+ """
+ element = element._annotate(annotations)
+ element._copy_internals()
+ return element
def splice_joins(left, right, stop_on=None):
if left is None:
@@ -639,7 +650,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""
def __init__(self, selectable, equivalents=None, include=None, exclude=None):
- self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]}
+ self.__traverse_options__ = {'stop_on':[selectable]}
self.selectable = selectable
self.include = include
self.exclude = exclude
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 0c6be97d7..b94f07f58 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -19,7 +19,7 @@ use a non-visitor traversal system.
For many examples of how the visit system is used, see the
sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules.
For an introduction to clause adaption, see
-http://techspot.zzzeek.org/?p=19 .
+http://techspot.zzzeek.org/2008/01/23/expression-transformations/
"""
@@ -212,55 +212,51 @@ def traverse_depthfirst(obj, opts, visitors):
return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
def cloned_traverse(obj, opts, visitors):
- """clone the given expression structure, allowing modifications by visitors."""
+ """clone the given expression structure, allowing
+ modifications by visitors."""
cloned = util.column_dict()
+ stop_on = util.column_set(opts.get('stop_on', []))
- def clone(element):
- if element not in cloned:
- cloned[element] = element._clone()
- return cloned[element]
-
- obj = clone(obj)
- stack = [obj]
-
- while stack:
- t = stack.pop()
- if t in cloned:
- continue
- t._copy_internals(clone=clone)
-
- meth = visitors.get(t.__visit_name__, None)
- if meth:
- meth(t)
-
- for c in t.get_children(**opts):
- stack.append(c)
+ def clone(elem):
+ if elem in stop_on:
+ return elem
+ else:
+ if elem not in cloned:
+ cloned[elem] = newelem = elem._clone()
+ newelem._copy_internals(clone=clone)
+ meth = visitors.get(newelem.__visit_name__, None)
+ if meth:
+ meth(newelem)
+ return cloned[elem]
+
+ if obj is not None:
+ obj = clone(obj)
return obj
+
def replacement_traverse(obj, opts, replace):
- """clone the given expression structure, allowing element replacement by a given replacement function."""
+ """clone the given expression structure, allowing element
+ replacement by a given replacement function."""
cloned = util.column_dict()
stop_on = util.column_set(opts.get('stop_on', []))
- def clone(element):
- newelem = replace(element)
- if newelem is not None:
- stop_on.add(newelem)
- return newelem
-
- if element not in cloned:
- cloned[element] = element._clone()
- return cloned[element]
-
- obj = clone(obj)
- stack = [obj]
- while stack:
- t = stack.pop()
- if t in stop_on:
- continue
- t._copy_internals(clone=clone)
- for c in t.get_children(**opts):
- stack.append(c)
+ def clone(elem, **kw):
+ if elem in stop_on or \
+ 'no_replacement_traverse' in elem._annotations:
+ return elem
+ else:
+ newelem = replace(elem)
+ if newelem is not None:
+ stop_on.add(newelem)
+ return newelem
+ else:
+ if elem not in cloned:
+ cloned[elem] = newelem = elem._clone()
+ newelem._copy_internals(clone=clone, **kw)
+ return cloned[elem]
+
+ if obj is not None:
+ obj = clone(obj, **opts)
return obj