summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoras1371 <sarraf.artin@gmail.com>2021-05-19 16:23:27 -1000
committerGitHub <noreply@github.com>2021-05-19 19:23:27 -0700
commite8914bb5681b6fad8a6764406d3c7a78ebc582ae (patch)
treeb3add2e0e9fcdf13be780573878a81eae9f3cdb2
parentd04265658e17ad40fb928677fc220675ef8bf09f (diff)
downloadnetworkx-e8914bb5681b6fad8a6764406d3c7a78ebc582ae.tar.gz
Add topological_generations function (#4757)
Adds a topological_generations function and refactor topological_sort to yield from topological_generations. Co-authored-by: Ross Barnowski <rossbar@berkeley.edu> Co-authored-by: Dan Schult <dschult@colgate.edu>
-rw-r--r--doc/reference/algorithms/dag.rst1
-rw-r--r--networkx/algorithms/dag.py114
-rw-r--r--networkx/algorithms/tests/test_dag.py34
3 files changed, 119 insertions, 30 deletions
diff --git a/doc/reference/algorithms/dag.rst b/doc/reference/algorithms/dag.rst
index 19edbcfa..f1cadf8a 100644
--- a/doc/reference/algorithms/dag.rst
+++ b/doc/reference/algorithms/dag.rst
@@ -9,6 +9,7 @@ Directed Acyclic Graphs
ancestors
descendants
topological_sort
+ topological_generations
all_topological_sorts
lexicographical_topological_sort
is_directed_acyclic_graph
diff --git a/networkx/algorithms/dag.py b/networkx/algorithms/dag.py
index 4b76f4f7..7b38e5ea 100644
--- a/networkx/algorithms/dag.py
+++ b/networkx/algorithms/dag.py
@@ -20,6 +20,7 @@ __all__ = [
"topological_sort",
"lexicographical_topological_sort",
"all_topological_sorts",
+ "topological_generations",
"is_directed_acyclic_graph",
"is_aperiodic",
"transitive_closure",
@@ -101,6 +102,83 @@ def is_directed_acyclic_graph(G):
return G.is_directed() and not has_cycle(G)
+def topological_generations(G):
+ """Stratifies a DAG into generations.
+
+ A topological generation is node collection in which ancestors of a node in each
+ generation are guaranteed to be in a previous generation, and any descendants of
+ a node are guaranteed to be in a following generation. Nodes are guaranteed to
+ be in the earliest possible generation that they can belong to.
+
+ Parameters
+ ----------
+ G : NetworkX digraph
+ A directed acyclic graph (DAG)
+
+ Yields
+ ------
+ sets of nodes
+ Yields sets of nodes representing each generation.
+
+ Raises
+ ------
+ NetworkXError
+ Generations are defined for directed graphs only. If the graph
+ `G` is undirected, a :exc:`NetworkXError` is raised.
+
+ NetworkXUnfeasible
+ If `G` is not a directed acyclic graph (DAG) no topological generations
+ exist and a :exc:`NetworkXUnfeasible` exception is raised. This can also
+ be raised if `G` is changed while the returned iterator is being processed
+
+ RuntimeError
+ If `G` is changed while the returned iterator is being processed.
+
+ Examples
+ --------
+ >>> DG = nx.DiGraph([(2, 1), (3, 1)])
+ >>> [sorted(generation) for generation in nx.topological_generations(DG)]
+ [[2, 3], [1]]
+
+ Notes
+ -----
+ The generation in which a node resides can also be determined by taking the
+ max-path-distance from the node to the farthest leaf node. That value can
+ be obtained with this function using `enumerate(topological_generations(G))`.
+
+ See also
+ --------
+ topological_sort
+ """
+ if not G.is_directed():
+ raise nx.NetworkXError("Topological sort not defined on undirected graphs.")
+
+ multigraph = G.is_multigraph()
+ indegree_map = {v: d for v, d in G.in_degree() if d > 0}
+ zero_indegree = [v for v, d in G.in_degree() if d == 0]
+
+ while zero_indegree:
+ this_generation = zero_indegree
+ zero_indegree = []
+ for node in this_generation:
+ if node not in G:
+ raise RuntimeError("Graph changed during iteration")
+ for child in G.neighbors(node):
+ try:
+ indegree_map[child] -= len(G[node][child]) if multigraph else 1
+ except KeyError as e:
+ raise RuntimeError("Graph changed during iteration") from e
+ if indegree_map[child] == 0:
+ zero_indegree.append(child)
+ del indegree_map[child]
+ yield this_generation
+
+ if indegree_map:
+ raise nx.NetworkXUnfeasible(
+ "Graph contains a cycle or graph changed during iteration"
+ )
+
+
def topological_sort(G):
"""Returns a generator of nodes in topologically sorted order.
@@ -114,10 +192,10 @@ def topological_sort(G):
G : NetworkX digraph
A directed acyclic graph (DAG)
- Returns
- -------
- iterable
- An iterable of node names in topological sorted order.
+ Yields
+ ------
+ nodes
+ Yields the nodes in topological sorted order.
Raises
------
@@ -165,32 +243,8 @@ def topological_sort(G):
.. [1] Manber, U. (1989).
*Introduction to Algorithms - A Creative Approach.* Addison-Wesley.
"""
- if not G.is_directed():
- raise nx.NetworkXError("Topological sort not defined on undirected graphs.")
-
- indegree_map = {v: d for v, d in G.in_degree() if d > 0}
- # These nodes have zero indegree and ready to be returned.
- zero_indegree = [v for v, d in G.in_degree() if d == 0]
-
- while zero_indegree:
- node = zero_indegree.pop()
- if node not in G:
- raise RuntimeError("Graph changed during iteration")
- for _, child in G.edges(node):
- try:
- indegree_map[child] -= 1
- except KeyError as e:
- raise RuntimeError("Graph changed during iteration") from e
- if indegree_map[child] == 0:
- zero_indegree.append(child)
- del indegree_map[child]
-
- yield node
-
- if indegree_map:
- raise nx.NetworkXUnfeasible(
- "Graph contains a cycle or graph changed " "during iteration"
- )
+ for generation in nx.topological_generations(G):
+ yield from generation
def lexicographical_topological_sort(G, key=None):
diff --git a/networkx/algorithms/tests/test_dag.py b/networkx/algorithms/tests/test_dag.py
index 8e1d7cc8..c0fac32c 100644
--- a/networkx/algorithms/tests/test_dag.py
+++ b/networkx/algorithms/tests/test_dag.py
@@ -478,6 +478,40 @@ class TestDAG:
assert sorting == test_nodes
+def test_topological_generations():
+ G = nx.DiGraph(
+ {
+ 1: [2, 3],
+ 2: [4, 5],
+ 3: [7],
+ 4: [],
+ 5: [6, 7],
+ 6: [],
+ 7: [],
+ }
+ ).reverse()
+ # order within each generation is inconsequential
+ generations = [sorted(gen) for gen in nx.topological_generations(G)]
+ expected = [[4, 6, 7], [3, 5], [2], [1]]
+ assert generations == expected
+
+ MG = nx.MultiDiGraph(G.edges)
+ MG.add_edge(2, 1)
+ generations = [sorted(gen) for gen in nx.topological_generations(MG)]
+ assert generations == expected
+
+
+def test_topological_generations_empty():
+ G = nx.DiGraph()
+ assert list(nx.topological_generations(G)) == []
+
+
+def test_topological_generations_cycle():
+ G = nx.DiGraph([[2, 1], [3, 1], [1, 2]])
+ with pytest.raises(nx.NetworkXUnfeasible):
+ list(nx.topological_generations(G))
+
+
def test_is_aperiodic_cycle():
G = nx.DiGraph()
nx.add_cycle(G, [1, 2, 3, 4])