summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/visitors.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
-rw-r--r--lib/sqlalchemy/sql/visitors.py282
1 files changed, 139 insertions, 143 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 9888a228a..738dae9c7 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -1,138 +1,29 @@
from sqlalchemy import util
class ClauseVisitor(object):
- """Traverses and visits ``ClauseElement`` structures.
-
- Calls visit_XXX() methods for each particular
- ``ClauseElement`` subclass encountered. Traversal of a
- hierarchy of ``ClauseElements`` is achieved via the
- ``traverse()`` method, which is passed the lead
- ``ClauseElement``.
-
- By default, ``ClauseVisitor`` traverses all elements
- fully. Options can be specified at the class level via the
- ``__traverse_options__`` dictionary which will be passed
- to the ``get_children()`` method of each ``ClauseElement``;
- these options can indicate modifications to the set of
- elements returned, such as to not return column collections
- (column_collections=False) or to return Schema-level items
- (schema_visitor=True).
-
- ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
- operation, which will produce a copy of a given ``ClauseElement``
- structure while at the same time allowing ``ClauseVisitor`` subclasses
- to modify the new structure in-place.
-
- """
__traverse_options__ = {}
- def traverse_single(self, obj, **kwargs):
- """visit a single element, without traversing its child elements."""
-
+ def traverse_single(self, obj):
for v in self._iterate_visitors:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
- return meth(obj, **kwargs)
+ return meth(obj)
- traverse_chained = traverse_single
-
def iterate(self, obj):
"""traverse the given expression structure, returning an iterator of all elements."""
-
- stack = [obj]
- traversal = util.deque()
- while stack:
- t = stack.pop()
- traversal.appendleft(t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- return iter(traversal)
-
- def traverse(self, obj, clone=False):
- """traverse and visit the given expression structure.
-
- Returns the structure given, or a copy of the structure if
- clone=True.
-
- When the copy operation takes place, the before_clone() method
- will receive each element before it is copied. If the method
- returns a non-None value, the return value is taken as the
- "copied" element and traversal will not descend further.
-
- The visit_XXX() methods receive the element *after* it's been
- copied. To compare an element to another regardless of
- one element being a cloned copy of the original, the
- '_cloned_set' attribute of ClauseElement can be used for the compare,
- i.e.::
-
- original in copied._cloned_set
-
-
- """
- if clone:
- return self._cloned_traversal(obj)
- else:
- return self._non_cloned_traversal(obj)
-
- def copy_and_process(self, list_):
- """Apply cloned traversal to the given list of elements, and return the new list."""
-
- return [self._cloned_traversal(x) for x in list_]
- def before_clone(self, elem):
- """receive pre-copied elements during a cloning traversal.
-
- If the method returns a new element, the element is used
- instead of creating a simple copy of the element. Traversal
- will halt on the newly returned element if it is re-encountered.
- """
- return None
-
- def _clone_element(self, elem, stop_on, cloned):
- for v in self._iterate_visitors:
- newelem = v.before_clone(elem)
- if newelem:
- stop_on.add(newelem)
- return newelem
-
- if elem not in cloned:
- # the full traversal will only make a clone of a particular element
- # once.
- cloned[elem] = elem._clone()
- return cloned[elem]
-
- def _cloned_traversal(self, obj):
- """a recursive traversal which creates copies of elements, returning the new structure."""
-
- stop_on = self.__traverse_options__.get('stop_on', [])
- return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True)
-
- def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False):
- if elem in stop_on:
- return elem
-
- if _clone_toplevel:
- elem = self._clone_element(elem, stop_on, cloned)
- if elem in stop_on:
- return elem
-
- def clone(element):
- return self._clone_element(element, stop_on, cloned)
- elem._copy_internals(clone=clone)
+ return iterate(obj, self.__traverse_options__)
- self.traverse_single(elem)
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
- for e in elem.get_children(**self.__traverse_options__):
- if e not in stop_on:
- self._cloned_traversal_impl(e, stop_on, cloned)
- return elem
+ visitors = {}
- def _non_cloned_traversal(self, obj):
- """a non-recursive, non-cloning traversal."""
-
- for target in self.iterate(obj):
- self.traverse_single(target)
- return obj
+ for name in dir(self):
+ if name.startswith('visit_'):
+ visitors[name[6:]] = getattr(self, name)
+
+ return traverse(obj, self.__traverse_options__, visitors)
def _iterate_visitors(self):
"""iterate through this visitor and each 'chained' visitor."""
@@ -152,31 +43,136 @@ class ClauseVisitor(object):
tail._next = visitor
return self
-class NoColumnVisitor(ClauseVisitor):
- """ClauseVisitor with 'column_collections' set to False; will not
- traverse the front-facing Column collections on Table, Alias, Select,
- and CompoundSelect objects.
+class CloningVisitor(ClauseVisitor):
+ def copy_and_process(self, list_):
+ """Apply cloned traversal to the given list of elements, and return the new list."""
+
+ return [self.traverse(x) for x in list_]
+
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
+
+ visitors = {}
+
+ for name in dir(self):
+ if name.startswith('visit_'):
+ visitors[name[6:]] = getattr(self, name)
+
+ return cloned_traverse(obj, self.__traverse_options__, visitors)
+
+class ReplacingCloningVisitor(CloningVisitor):
+ def replace(self, elem):
+ """receive pre-copied elements during a cloning traversal.
+
+ If the method returns a new element, the element is used
+ instead of creating a simple copy of the element. Traversal
+ will halt on the newly returned element if it is re-encountered.
+ """
+ return None
+
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
+
+ def replace(elem):
+ for v in self._iterate_visitors:
+ e = v.replace(elem)
+ if e:
+ return e
+ return replacement_traverse(obj, self.__traverse_options__, replace)
+
+def iterate(obj, opts):
+ """traverse the given expression structure, returning an iterator.
+
+ traversal is configured to be breadth-first.
"""
+ stack = util.deque([obj])
+ while stack:
+ t = stack.popleft()
+ yield t
+ for c in t.get_children(**opts):
+ stack.append(c)
+
+def iterate_depthfirst(obj, opts):
+ """traverse the given expression structure, returning an iterator.
- __traverse_options__ = {'column_collections':False}
-
-class NullVisitor(ClauseVisitor):
- def traverse(self, obj, clone=False):
- next = getattr(self, '_next', None)
- if next:
- return next.traverse(obj, clone=clone)
- else:
- return obj
-
-def traverse(clause, **kwargs):
- """traverse the given clause, applying visit functions passed in as keyword arguments."""
+ traversal is configured to be depth-first.
+
+ """
+ stack = util.deque([obj])
+ traversal = util.deque()
+ while stack:
+ t = stack.pop()
+ traversal.appendleft(t)
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return iter(traversal)
+
+def traverse_using(iterator, obj, visitors):
+ """visit the given expression structure using the given iterator of objects."""
+
+ for target in iterator:
+ meth = visitors.get(target.__visit_name__, None)
+ if meth:
+ meth(target)
+ return obj
- clone = kwargs.pop('clone', False)
- class Vis(ClauseVisitor):
- __traverse_options__ = kwargs.pop('traverse_options', {})
- vis = Vis()
- for key in kwargs:
- setattr(vis, key, kwargs[key])
- return vis.traverse(clause, clone=clone)
+def traverse(obj, opts, visitors):
+ """traverse and visit the given expression structure using the default iterator."""
+
+ return traverse_using(iterate(obj, opts), obj, visitors)
+
+def traverse_depthfirst(obj, opts, visitors):
+ """traverse and visit the given expression structure using the depth-first iterator."""
+
+ return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
+
+def cloned_traverse(obj, opts, visitors):
+ cloned = {}
+
+ def clone(element):
+ if element not in cloned:
+ cloned[element] = element._clone()
+ return cloned[element]
+
+ obj = clone(obj)
+ stack = [obj]
+
+ while stack:
+ t = stack.pop()
+ if t in cloned:
+ continue
+ t._copy_internals(clone=clone)
+
+ meth = visitors.get(t.__visit_name__, None)
+ if meth:
+ meth(t)
+
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return obj
+
+def replacement_traverse(obj, opts, replace):
+ cloned = {}
+ stop_on = util.Set(opts.get('stop_on', []))
+
+ def clone(element):
+ newelem = replace(element)
+ if newelem:
+ stop_on.add(newelem)
+ return newelem
+
+ if element not in cloned:
+ cloned[element] = element._clone()
+ return cloned[element]
+ obj = clone(obj)
+ stack = [obj]
+ while stack:
+ t = stack.pop()
+ if t in stop_on:
+ continue
+ t._copy_internals(clone=clone)
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return obj