1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
class ClauseVisitor(object):
    """Traverses and visits ``ClauseElement`` structures.

    Calls ``visit_XXX()`` methods, looked up dynamically from each
    element's ``__visit_name__``, for every ``ClauseElement`` encountered.
    Traversal of a hierarchy of ``ClauseElements`` is achieved via the
    ``traverse()`` method, which is passed the lead ``ClauseElement``.

    By default, ``ClauseVisitor`` traverses all elements fully. Options can
    be specified at the class level via the ``__traverse_options__``
    dictionary, which is passed as keyword arguments to the
    ``get_children()`` method of each ``ClauseElement``; these options can
    indicate modifications to the set of elements returned, such as to not
    return column collections (``column_collections=False``) or to return
    Schema-level items (``schema_visitor=True``).

    ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
    operation, which produces a copy of a given ``ClauseElement`` structure
    while at the same time allowing ``ClauseVisitor`` subclasses to modify
    the new structure in-place.

    Additional visitors can be appended with ``chain()``; they receive
    every visit event after this one.
    """

    __traverse_options__ = {}

    def traverse_single(self, obj, **kwargs):
        """Dispatch ``obj`` to this visitor's own ``visit_<name>`` method.

        Chained visitors are not consulted and children are not visited.
        Returns the visit method's result, or ``None`` when this visitor
        has no method for ``obj.__visit_name__``.
        """
        meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
        if meth:
            return meth(obj, **kwargs)

    def traverse_chained(self, obj, **kwargs):
        """Dispatch ``obj`` to this visitor and every visitor chained after it.

        Children of ``obj`` are not visited. Each visitor in the ``_next``
        chain that defines ``visit_<name>`` is called with ``obj`` and
        ``kwargs``, in chain order.
        """
        v = self
        while v is not None:
            # Look the method up on the current link ``v``, not ``self``;
            # otherwise this visitor's method would run once per link and
            # the chained visitors would never be invoked.
            meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
            if meth:
                meth(obj, **kwargs)
            v = getattr(v, '_next', None)

    def iterate(self, obj):
        """Yield ``obj`` and all of its descendants from a depth-first stack.

        Children come from ``get_children(**self.__traverse_options__)``.
        Each element is yielded before its children; because children are
        popped from a stack, the last child returned is yielded first.
        """
        stack = [obj]
        while stack:
            t = stack.pop()
            yield t
            stack.extend(t.get_children(**self.__traverse_options__))

    def traverse(self, obj, clone=False):
        """Visit every element of the ``obj`` structure, descendants first.

        Elements are collected from a depth-first stack and then dispatched
        in the reverse of collection order, so each element is visited
        after its descendants and the lead element is visited last. Every
        visitor in the ``chain()`` sequence that defines ``visit_<name>``
        is called with the element as its only argument.

        :param obj: the lead ``ClauseElement``.
        :param clone: if true, a copy of the structure is produced via
          ``_clone()`` / ``_copy_internals()`` and the copy is what gets
          traversed; each distinct element is cloned at most once.
        :returns: ``obj``, or its clone when ``clone`` is true.
        """
        if clone:
            cloned = {}

            def do_clone(obj):
                # the full traversal will only make a clone of a particular
                # element once.
                if obj not in cloned:
                    cloned[obj] = obj._clone()
                return cloned[obj]

            obj = do_clone(obj)

        stack = [obj]
        traversal = []
        while stack:
            t = stack.pop()
            traversal.append(t)
            if clone:
                # Replace t's internals with clones before asking for its
                # children, so the traversal continues into the copy.
                t._copy_internals(clone=do_clone)
            stack.extend(t.get_children(**self.__traverse_options__))

        # Collected in pop order; dispatch in reverse so descendants are
        # visited before their ancestors.
        for target in reversed(traversal):
            v = self
            while v is not None:
                meth = getattr(v, "visit_%s" % target.__visit_name__, None)
                if meth:
                    meth(target)
                v = getattr(v, '_next', None)
        return obj

    def chain(self, visitor):
        """'chain' an additional ClauseVisitor onto this ClauseVisitor.

        The visitor is attached at the end of the existing ``_next`` chain
        and will receive all visit events after this one. Returns ``self``
        so calls can be chained fluently.
        """
        tail = self
        while getattr(tail, '_next', None) is not None:
            tail = tail._next
        tail._next = visitor
        return self
class NoColumnVisitor(ClauseVisitor):
    """A ``ClauseVisitor`` that skips front-facing Column collections.

    Its ``__traverse_options__`` set ``column_collections`` to ``False``,
    so the Column collections on Table, Alias, Select and CompoundSelect
    objects are not traversed.
    """

    __traverse_options__ = dict(column_collections=False)
def traverse(clause, **kwargs):
    """Traverse ``clause`` with an ad-hoc visitor built from keyword arguments.

    ``clone`` (default ``False``) is forwarded to ``ClauseVisitor.traverse``.
    ``traverse_options`` (default ``{}``) becomes the visitor's
    ``__traverse_options__``. Any other keyword argument is served as an
    attribute of the visitor, so ``visit_<name>=fn`` makes ``fn`` the
    handler for elements whose ``__visit_name__`` is ``<name>``; attributes
    that were not supplied resolve to ``None``.

    Returns the result of ``ClauseVisitor.traverse``: the clause, or its
    clone when ``clone`` is true.
    """
    clone = kwargs.pop('clone', False)
    options = kwargs.pop('traverse_options', {})
    handlers = kwargs

    class Vis(ClauseVisitor):
        __traverse_options__ = options

        def __getattr__(self, key):
            # Only reached for attributes not found on the instance/class,
            # i.e. the visit_* handlers (and _next, which stays None).
            return handlers.get(key)

    return Vis().traverse(clause, clone=clone)
|