summaryrefslogtreecommitdiff
path: root/taskflow/engines/action_engine/scopes.py
diff options
context:
space:
mode:
Diffstat (limited to 'taskflow/engines/action_engine/scopes.py')
-rw-r--r--taskflow/engines/action_engine/scopes.py39
1 files changed, 21 insertions, 18 deletions
diff --git a/taskflow/engines/action_engine/scopes.py b/taskflow/engines/action_engine/scopes.py
index 5fd7ee6..1d309d8 100644
--- a/taskflow/engines/action_engine/scopes.py
+++ b/taskflow/engines/action_engine/scopes.py
@@ -14,14 +14,14 @@
# License for the specific language governing permissions and limitations
# under the License.
-from taskflow import atom as atom_type
-from taskflow import flow as flow_type
+from taskflow.engines.action_engine import compiler as co
from taskflow import logging
LOG = logging.getLogger(__name__)
-def _extract_atoms_iter(node, idx=-1):
+def _depth_first_reverse_iterate(node, idx=-1):
+ """Iterates connected (in reverse) nodes in tree (from starting node)."""
# Always go left to right, since right to left is the pattern order
# and we want to go backwards and not forwards through that ordering...
if idx == -1:
@@ -29,15 +29,17 @@ def _extract_atoms_iter(node, idx=-1):
else:
children_iter = reversed(node[0:idx])
for child in children_iter:
- if isinstance(child.item, flow_type.Flow):
- for atom in _extract_atoms_iter(child):
+ child_kind = child.metadata['kind']
+ if child_kind == co.FLOW:
+ # Jump through these...
+ #
+ # TODO(harlowja): make this non-recursive and remove this
+ # style of doing this when
+ # https://review.openstack.org/#/c/205731/ merges...
+ for atom in _depth_first_reverse_iterate(child):
yield atom
- elif isinstance(child.item, atom_type.Atom):
- yield child.item
else:
- raise TypeError(
- "Unknown extraction item '%s' (%s)" % (child.item,
- type(child.item)))
+ yield child.item
class ScopeWalker(object):
@@ -57,13 +59,10 @@ class ScopeWalker(object):
" hierarchy" % atom)
self._level_cache = {}
self._atom = atom
- self._graph = compilation.execution_graph
+ self._execution_graph = compilation.execution_graph
self._names_only = names_only
self._predecessors = None
- #: Function that extracts the *associated* atoms of a given tree node.
- _extract_atoms_iter = staticmethod(_extract_atoms_iter)
-
def __iter__(self):
"""Iterates over the visible scopes.
@@ -99,10 +98,14 @@ class ScopeWalker(object):
nodes (aka we have reached the top of the tree) or we run out of
predecessors.
"""
+ graph = self._execution_graph
if self._predecessors is None:
- pred_iter = self._graph.bfs_predecessors_iter(self._atom)
- self._predecessors = set(pred_iter)
- predecessors = self._predecessors.copy()
+ predecessors = set(
+ node for node in graph.bfs_predecessors_iter(self._atom)
+ if graph.node[node]['kind'] in co.ATOMS)
+ self._predecessors = predecessors.copy()
+ else:
+ predecessors = self._predecessors.copy()
last = self._node
for lvl, parent in enumerate(self._node.path_iter(include_self=False)):
if not predecessors:
@@ -114,7 +117,7 @@ class ScopeWalker(object):
except KeyError:
visible = []
removals = set()
- for atom in self._extract_atoms_iter(parent, idx=last_idx):
+ for atom in _depth_first_reverse_iterate(parent, idx=last_idx):
if atom in predecessors:
predecessors.remove(atom)
removals.add(atom)