summaryrefslogtreecommitdiff
path: root/taskflow/patterns/graph_flow.py
diff options
context:
space:
mode:
Diffstat (limited to 'taskflow/patterns/graph_flow.py')
-rw-r--r--taskflow/patterns/graph_flow.py25
1 files changed, 21 insertions, 4 deletions
diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py
index 7db4fee..9717b21 100644
--- a/taskflow/patterns/graph_flow.py
+++ b/taskflow/patterns/graph_flow.py
@@ -16,8 +16,6 @@
import collections
-from networkx.algorithms import traversal
-
from taskflow import exceptions as exc
from taskflow import flow
from taskflow.types import graph as gr
@@ -170,6 +168,26 @@ class Flow(flow.Flow):
for (u, v, e_data) in self._get_subgraph().edges_iter(data=True):
yield (u, v, e_data)
+ @property
+ def requires(self):
+ requires = set()
+ retry_provides = set()
+ if self._retry is not None:
+ requires.update(self._retry.requires)
+ retry_provides.update(self._retry.provides)
+ g = self._get_subgraph()
+ for item in g.nodes_iter():
+ item_requires = item.requires - retry_provides
+ # Now scan predecessors to see if they provide what we want.
+ if item_requires:
+ for pred_item in g.bfs_predecessors_iter(item):
+ item_requires = item_requires - pred_item.provides
+ if not item_requires:
+ break
+ if item_requires:
+ requires.update(item_requires)
+ return frozenset(requires)
+
class TargetedFlow(Flow):
"""Graph flow with a target.
@@ -223,8 +241,7 @@ class TargetedFlow(Flow):
if self._target is None:
return self._graph
nodes = [self._target]
- nodes.extend(dst for _src, dst in
- traversal.dfs_edges(self._graph.reverse(), self._target))
+ nodes.extend(self._graph.bfs_predecessors_iter(self._target))
self._subgraph = self._graph.subgraph(nodes)
self._subgraph.freeze()
return self._subgraph