diff options
Diffstat (limited to 'taskflow/patterns/graph_flow.py')
| -rw-r--r-- | taskflow/patterns/graph_flow.py | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 7db4fee..9717b21 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -16,8 +16,6 @@ import collections -from networkx.algorithms import traversal - from taskflow import exceptions as exc from taskflow import flow from taskflow.types import graph as gr @@ -170,6 +168,26 @@ class Flow(flow.Flow): for (u, v, e_data) in self._get_subgraph().edges_iter(data=True): yield (u, v, e_data) + @property + def requires(self): + requires = set() + retry_provides = set() + if self._retry is not None: + requires.update(self._retry.requires) + retry_provides.update(self._retry.provides) + g = self._get_subgraph() + for item in g.nodes_iter(): + item_requires = item.requires - retry_provides + # Now scan predecessors to see if they provide what we want. + if item_requires: + for pred_item in g.bfs_predecessors_iter(item): + item_requires = item_requires - pred_item.provides + if not item_requires: + break + if item_requires: + requires.update(item_requires) + return frozenset(requires) + class TargetedFlow(Flow): """Graph flow with a target. @@ -223,8 +241,7 @@ class TargetedFlow(Flow): if self._target is None: return self._graph nodes = [self._target] - nodes.extend(dst for _src, dst in - traversal.dfs_edges(self._graph.reverse(), self._target)) + nodes.extend(self._graph.bfs_predecessors_iter(self._target)) self._subgraph = self._graph.subgraph(nodes) self._subgraph.freeze() return self._subgraph |
