summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichal Arbet <michal.arbet@ultimum.io>2018-06-25 16:06:18 +0200
committerMichal Arbet <michal.arbet@ultimum.io>2018-07-11 13:11:51 +0200
commitd985c5a256bcf8f67e80235ba387599f448f2c49 (patch)
treeb7f5f712a08a83090aefb43d4b0ce121fc6d2fca
parent19d0eb89c80655da766b6fc55b3ebf6b5a0d52f8 (diff)
downloadtaskflow-d985c5a256bcf8f67e80235ba387599f448f2c49.tar.gz
Fix code to support networkx > 1.0
With the release of NetworkX 2.0 the reporting API was moved to view/iterator model. Many methods were moved from reporting lists or dicts to iterating over the information. Methods that used to return containers now return views and methods that returned iterators have been removed in networkx. Because of this change in NetworkX 2.0 , taskflow code have to be changed also to support networkx > 2.0 Change-Id: I23c226f37bd85c1e38039fbcb302a2d0de49f333 Closes-Bug: #1778115
-rw-r--r--doc/requirements.txt2
-rw-r--r--requirements.txt2
-rw-r--r--setup.cfg2
-rw-r--r--taskflow/patterns/graph_flow.py2
-rw-r--r--taskflow/tests/unit/action_engine/test_compile.py4
-rw-r--r--taskflow/tests/unit/test_types.py12
-rw-r--r--taskflow/types/graph.py163
-rw-r--r--taskflow/types/tree.py2
-rw-r--r--taskflow/utils/misc.py5
-rw-r--r--test-requirements.txt3
10 files changed, 177 insertions, 20 deletions
diff --git a/doc/requirements.txt b/doc/requirements.txt
index 741c0e9..7229b2a 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -16,7 +16,7 @@ kazoo>=2.2 # Apache-2.0
zake>=0.1.6 # Apache-2.0
redis>=2.10.0 # MIT
-eventlet!=0.18.3,!=0.20.1,!=0.21.0 # MIT
+eventlet!=0.18.3,!=0.20.1,!=0.21.0,>=0.18.2 # MIT
SQLAlchemy!=1.1.5,!=1.1.6,!=1.1.7,!=1.1.8,>=1.0.10 # MIT
alembic>=0.8.10 # MIT
SQLAlchemy-Utils>=0.30.11 # BSD License
diff --git a/requirements.txt b/requirements.txt
index 2ab5ad9..b83e8ae 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,7 @@ fasteners>=0.7.0 # Apache-2.0
networkx>=1.10 # BSD
# For contextlib new additions/compatibility for <= python 3.3
-contextlib2>=0.4.0 # PSF License
+contextlib2>=0.4.0;python_version<'3.0' # PSF License
# Used for backend storage engine loading.
stevedore>=1.20.0 # Apache-2.0
diff --git a/setup.cfg b/setup.cfg
index 6c91c51..dd63094 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -66,7 +66,7 @@ redis =
workers =
kombu!=4.0.2,>=4.0.0 # BSD
eventlet =
- eventlet!=0.18.3,!=0.20.1,!=0.21.0 # MIT
+ eventlet!=0.18.3,!=0.20.1,!=0.21.0,>=0.18.2 # MIT
database =
SQLAlchemy!=1.1.5,!=1.1.6,!=1.1.7,!=1.1.8,>=1.0.10 # MIT
alembic>=0.8.10 # MIT
diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py
index 8296a6d..ea45f85 100644
--- a/taskflow/patterns/graph_flow.py
+++ b/taskflow/patterns/graph_flow.py
@@ -367,6 +367,6 @@ class TargetedFlow(Flow):
return self._graph
nodes = [self._target]
nodes.extend(self._graph.bfs_predecessors_iter(self._target))
- self._subgraph = self._graph.subgraph(nodes)
+ self._subgraph = gr.DiGraph(data=self._graph.subgraph(nodes))
self._subgraph.freeze()
return self._subgraph
diff --git a/taskflow/tests/unit/action_engine/test_compile.py b/taskflow/tests/unit/action_engine/test_compile.py
index 1a310d1..fb99439 100644
--- a/taskflow/tests/unit/action_engine/test_compile.py
+++ b/taskflow/tests/unit/action_engine/test_compile.py
@@ -73,7 +73,7 @@ class PatternCompileTest(test.TestCase):
compiler.PatternCompiler(flo).compile())
self.assertEqual(8, len(g))
- order = g.topological_sort()
+ order = list(g.topological_sort())
self.assertEqual(['test', 'a', 'b', 'c',
"sub-test", 'd', "sub-test[$]",
'test[$]'], order)
@@ -430,7 +430,7 @@ class PatternCompileTest(test.TestCase):
self.assertTrue(g.has_edge(flow, a))
self.assertTrue(g.has_edge(a, empty_flow))
- empty_flow_successors = g.successors(empty_flow)
+ empty_flow_successors = list(g.successors(empty_flow))
self.assertEqual(1, len(empty_flow_successors))
empty_flow_terminal = empty_flow_successors[0]
self.assertIs(empty_flow, empty_flow_terminal.flow)
diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py
index ba9d8fe..f1e84d1 100644
--- a/taskflow/tests/unit/test_types.py
+++ b/taskflow/tests/unit/test_types.py
@@ -576,12 +576,12 @@ CEO
self.assertEqual(root.child_count(only_direct=False) + 1, len(g))
for node in root.dfs_iter(include_self=True):
self.assertIn(node.item, g)
- self.assertEqual([], g.predecessors('animal'))
- self.assertEqual(['animal'], g.predecessors('reptile'))
- self.assertEqual(['primate'], g.predecessors('human'))
- self.assertEqual(['mammal'], g.predecessors('primate'))
- self.assertEqual(['animal'], g.predecessors('mammal'))
- self.assertEqual(['mammal', 'reptile'], g.successors('animal'))
+ self.assertEqual([], list(g.predecessors('animal')))
+ self.assertEqual(['animal'], list(g.predecessors('reptile')))
+ self.assertEqual(['primate'], list(g.predecessors('human')))
+ self.assertEqual(['mammal'], list(g.predecessors('primate')))
+ self.assertEqual(['animal'], list(g.predecessors('mammal')))
+ self.assertEqual(['mammal', 'reptile'], list(g.successors('animal')))
def test_to_digraph_retains_metadata(self):
root = tree.Node("chickens", alive=True)
diff --git a/taskflow/types/graph.py b/taskflow/types/graph.py
index aebbf7b..553e690 100644
--- a/taskflow/types/graph.py
+++ b/taskflow/types/graph.py
@@ -21,6 +21,8 @@ import networkx as nx
from networkx.drawing import nx_pydot
import six
+from taskflow.utils import misc
+
def _common_format(g, edge_notation):
lines = []
@@ -47,7 +49,10 @@ class Graph(nx.Graph):
"""A graph subclass with useful utility functions."""
def __init__(self, data=None, name=''):
- super(Graph, self).__init__(name=name, data=data)
+ if misc.nx_version() == '1':
+ super(Graph, self).__init__(name=name, data=data)
+ else:
+ super(Graph, self).__init__(name=name, incoming_graph_data=data)
self.frozen = False
def freeze(self):
@@ -64,12 +69,67 @@ class Graph(nx.Graph):
"""Pretty formats your graph into a string."""
return os.linesep.join(_common_format(self, "<->"))
+ def nodes_iter(self, data=False):
+ """Returns an iterable object over the nodes.
+
+ Type of iterable returned object depends on which version
+ of networkx is used. When networkx < 2.0 is used , method
+ returns an iterator, but if networkx > 2.0 is used, it returns
+ NodeView of the Graph which is also iterable.
+ """
+ if misc.nx_version() == '1':
+ return super(Graph, self).nodes_iter(data=data)
+ return super(Graph, self).nodes(data=data)
+
+ def edges_iter(self, nbunch=None, data=False, default=None):
+ """Returns an iterable object over the edges.
+
+ Type of iterable returned object depends on which version
+ of networkx is used. When networkx < 2.0 is used , method
+ returns an iterator, but if networkx > 2.0 is used, it returns
+ EdgeView of the Graph which is also iterable.
+ """
+ if misc.nx_version() == '1':
+ return super(Graph, self).edges_iter(nbunch=nbunch, data=data,
+ default=default)
+ return super(Graph, self).edges(nbunch=nbunch, data=data,
+ default=default)
+
+ def add_edge(self, u, v, attr_dict=None, **attr):
+ """Add an edge between u and v."""
+ if misc.nx_version() == '1':
+ return super(Graph, self).add_edge(u, v, attr_dict=attr_dict,
+ **attr)
+ if attr_dict is not None:
+ return super(Graph, self).add_edge(u, v, **attr_dict)
+ return super(Graph, self).add_edge(u, v, **attr)
+
+ def add_node(self, n, attr_dict=None, **attr):
+ """Add a single node n and update node attributes."""
+ if misc.nx_version() == '1':
+ return super(Graph, self).add_node(n, attr_dict=attr_dict, **attr)
+ if attr_dict is not None:
+ return super(Graph, self).add_node(n, **attr_dict)
+ return super(Graph, self).add_node(n, **attr)
+
+ def fresh_copy(self):
+ """Return a fresh copy graph with the same data structure.
+
+ A fresh copy has no nodes, edges or graph attributes. It is
+ the same data structure as the current graph. This method is
+ typically used to create an empty version of the graph.
+ """
+ return Graph()
+
class DiGraph(nx.DiGraph):
"""A directed graph subclass with useful utility functions."""
def __init__(self, data=None, name=''):
- super(DiGraph, self).__init__(name=name, data=data)
+ if misc.nx_version() == '1':
+ super(DiGraph, self).__init__(name=name, data=data)
+ else:
+ super(DiGraph, self).__init__(name=name, incoming_graph_data=data)
self.frozen = False
def freeze(self):
@@ -124,13 +184,13 @@ class DiGraph(nx.DiGraph):
def no_successors_iter(self):
"""Returns an iterator for all nodes with no successors."""
for n in self.nodes_iter():
- if not len(self.successors(n)):
+ if not len(list(self.successors(n))):
yield n
def no_predecessors_iter(self):
"""Returns an iterator for all nodes with no predecessors."""
for n in self.nodes_iter():
- if not len(self.predecessors(n)):
+ if not len(list(self.predecessors(n))):
yield n
def bfs_predecessors_iter(self, n):
@@ -153,6 +213,71 @@ class DiGraph(nx.DiGraph):
if pred_pred not in visited:
queue.append(pred_pred)
+ def add_edge(self, u, v, attr_dict=None, **attr):
+ """Add an edge between u and v."""
+ if misc.nx_version() == '1':
+ return super(DiGraph, self).add_edge(u, v, attr_dict=attr_dict,
+ **attr)
+ if attr_dict is not None:
+ return super(DiGraph, self).add_edge(u, v, **attr_dict)
+ return super(DiGraph, self).add_edge(u, v, **attr)
+
+ def add_node(self, n, attr_dict=None, **attr):
+ """Add a single node n and update node attributes."""
+ if misc.nx_version() == '1':
+ return super(DiGraph, self).add_node(n, attr_dict=attr_dict,
+ **attr)
+ if attr_dict is not None:
+ return super(DiGraph, self).add_node(n, **attr_dict)
+ return super(DiGraph, self).add_node(n, **attr)
+
+ def successors_iter(self, n):
+ """Returns an iterator over successor nodes of n."""
+ if misc.nx_version() == '1':
+ return super(DiGraph, self).successors_iter(n)
+ return super(DiGraph, self).successors(n)
+
+ def predecessors_iter(self, n):
+ """Return an iterator over predecessor nodes of n."""
+ if misc.nx_version() == '1':
+ return super(DiGraph, self).predecessors_iter(n)
+ return super(DiGraph, self).predecessors(n)
+
+ def nodes_iter(self, data=False):
+ """Returns an iterable object over the nodes.
+
+ Type of iterable returned object depends on which version
+ of networkx is used. When networkx < 2.0 is used , method
+ returns an iterator, but if networkx > 2.0 is used, it returns
+ NodeView of the Graph which is also iterable.
+ """
+ if misc.nx_version() == '1':
+ return super(DiGraph, self).nodes_iter(data=data)
+ return super(DiGraph, self).nodes(data=data)
+
+ def edges_iter(self, nbunch=None, data=False, default=None):
+ """Returns an iterable object over the edges.
+
+ Type of iterable returned object depends on which version
+ of networkx is used. When networkx < 2.0 is used , method
+ returns an iterator, but if networkx > 2.0 is used, it returns
+ EdgeView of the Graph which is also iterable.
+ """
+ if misc.nx_version() == '1':
+ return super(DiGraph, self).edges_iter(nbunch=nbunch, data=data,
+ default=default)
+ return super(DiGraph, self).edges(nbunch=nbunch, data=data,
+ default=default)
+
+ def fresh_copy(self):
+ """Return a fresh copy graph with the same data structure.
+
+ A fresh copy has no nodes, edges or graph attributes. It is
+ the same data structure as the current graph. This method is
+ typically used to create an empty version of the graph.
+ """
+ return DiGraph()
+
class OrderedDiGraph(DiGraph):
"""A directed graph subclass with useful utility functions.
@@ -162,9 +287,22 @@ class OrderedDiGraph(DiGraph):
order).
"""
node_dict_factory = collections.OrderedDict
- adjlist_dict_factory = collections.OrderedDict
+ if misc.nx_version() == '1':
+ adjlist_dict_factory = collections.OrderedDict
+ else:
+ adjlist_outer_dict_factory = collections.OrderedDict
+ adjlist_inner_dict_factory = collections.OrderedDict
edge_attr_dict_factory = collections.OrderedDict
+ def fresh_copy(self):
+ """Return a fresh copy graph with the same data structure.
+
+ A fresh copy has no nodes, edges or graph attributes. It is
+ the same data structure as the current graph. This method is
+ typically used to create an empty version of the graph.
+ """
+ return OrderedDiGraph()
+
class OrderedGraph(Graph):
"""A graph subclass with useful utility functions.
@@ -174,9 +312,22 @@ class OrderedGraph(Graph):
order).
"""
node_dict_factory = collections.OrderedDict
- adjlist_dict_factory = collections.OrderedDict
+ if misc.nx_version() == '1':
+ adjlist_dict_factory = collections.OrderedDict
+ else:
+ adjlist_outer_dict_factory = collections.OrderedDict
+ adjlist_inner_dict_factory = collections.OrderedDict
edge_attr_dict_factory = collections.OrderedDict
+ def fresh_copy(self):
+ """Return a fresh copy graph with the same data structure.
+
+ A fresh copy has no nodes, edges or graph attributes. It is
+ the same data structure as the current graph. This method is
+ typically used to create an empty version of the graph.
+ """
+ return OrderedGraph()
+
def merge_graphs(graph, *graphs, **kwargs):
"""Merges a bunch of graphs into a new graph.
diff --git a/taskflow/types/tree.py b/taskflow/types/tree.py
index 9269153..3681694 100644
--- a/taskflow/types/tree.py
+++ b/taskflow/types/tree.py
@@ -402,7 +402,7 @@ class Node(object):
"""
g = graph.OrderedDiGraph()
for node in self.bfs_iter(include_self=True, right_to_left=True):
- g.add_node(node.item, attr_dict=node.metadata)
+ g.add_node(node.item, **node.metadata)
if node is not self:
g.add_edge(node.parent.item, node.item)
return g
diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py
index 421449d..123ff89 100644
--- a/taskflow/utils/misc.py
+++ b/taskflow/utils/misc.py
@@ -27,6 +27,7 @@ import threading
import types
import enum
+import networkx as nx
from oslo_serialization import jsonutils
from oslo_serialization import msgpackutils
from oslo_utils import encodeutils
@@ -539,3 +540,7 @@ def safe_copy_dict(obj):
return {}
# default to a shallow copy to avoid most ownership issues
return dict(obj)
+
+
+def nx_version():
+ return nx.__version__.split('.')[0]
diff --git a/test-requirements.txt b/test-requirements.txt
index 6304248..2f6805a 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -13,7 +13,7 @@ redis>=2.10.0 # MIT
kombu!=4.0.2,>=4.0.0 # BSD
# eventlet
-eventlet!=0.18.3,!=0.20.1,!=0.21.0 # MIT
+eventlet!=0.18.3,!=0.20.1,!=0.21.0,>=0.18.2 # MIT
# database
SQLAlchemy!=1.1.5,!=1.1.6,!=1.1.7,!=1.1.8,>=1.0.10 # MIT
@@ -29,5 +29,6 @@ oslotest>=3.2.0 # Apache-2.0
mock>=2.0.0 # BSD
testtools>=2.2.0 # MIT
testscenarios>=0.4 # Apache-2.0/BSD
+testrepository>=0.0.18 # Apache-2.0/BSD
doc8>=0.6.0 # Apache-2.0
sphinx!=1.6.6,!=1.6.7,>=1.6.2 # BSD