summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlimi Qudirah <qudrohbidemi@gmail.com>2023-05-13 13:22:45 +0100
committerGitHub <noreply@github.com>2023-05-13 14:22:45 +0200
commit90bb1f0705f34c6570fe9fa537f672f9ddb3372c (patch)
tree4074bd018d91378bf54f27d6a8d9564f751d8553
parentc051696edfda4fb0cf27298259b723c57459ab60 (diff)
downloadnetworkx-90bb1f0705f34c6570fe9fa537f672f9ddb3372c.tar.gz
Improve test coverage for mst.py (#6540)
* fixes for 6539 * add more tests to test_mst.py
-rw-r--r--networkx/algorithms/tree/tests/test_mst.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/networkx/algorithms/tree/tests/test_mst.py b/networkx/algorithms/tree/tests/test_mst.py
index ee0c3f60..373f16cf 100644
--- a/networkx/algorithms/tree/tests/test_mst.py
+++ b/networkx/algorithms/tree/tests/test_mst.py
@@ -9,6 +9,10 @@ from networkx.utils import edges_equal, nodes_equal
def test_unknown_algorithm():
with pytest.raises(ValueError):
nx.minimum_spanning_tree(nx.Graph(), algorithm="random")
+ with pytest.raises(
+ ValueError, match="random is not a valid choice for an algorithm."
+ ):
+ nx.maximum_spanning_edges(nx.Graph(), algorithm="random")
class MinimumSpanningTreeTestBase:
@@ -104,6 +108,19 @@ class MinimumSpanningTreeTestBase:
with pytest.raises(ValueError):
list(edges)
+ def test_nan_weights_MultiGraph(self):
+ G = nx.MultiGraph()
+ G.add_edge(0, 12, weight=float("nan"))
+ edges = nx.minimum_spanning_edges(
+ G, algorithm="prim", data=False, ignore_nan=False
+ )
+ with pytest.raises(ValueError):
+ list(edges)
+ # test default for ignore_nan as False
+ edges = nx.minimum_spanning_edges(G, algorithm="prim", data=False)
+ with pytest.raises(ValueError):
+ list(edges)
+
def test_nan_weights_order(self):
# now try again with a nan edge at the beginning of G.nodes
edges = [
@@ -283,6 +300,14 @@ class TestKruskal(MultigraphMSTTestBase):
)
assert edges_equal([(1, 2), (2, 3)], list(mst_edges))
+ # both keys and data are included
+ mst_edges = nx.minimum_spanning_edges(
+ G, algorithm=self.algo, keys=True, data=True
+ )
+ assert edges_equal(
+ [(1, 2, 1, {"weight": 2}), (2, 3, 1, {"weight": 2})], list(mst_edges)
+ )
+
class TestPrim(MultigraphMSTTestBase):
"""Unit tests for computing a minimum (or maximum) spanning tree
@@ -291,6 +316,18 @@ class TestPrim(MultigraphMSTTestBase):
algorithm = "prim"
+ def test_prim_mst_edges_simple_graph(self):
+ H = nx.Graph()
+ H.add_edge(1, 2, key=2, weight=3)
+ H.add_edge(3, 2, key=1, weight=2)
+ H.add_edge(3, 1, key=1, weight=4)
+
+ mst_edges = nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=True)
+ assert edges_equal(
+ [(1, 2, {"key": 2, "weight": 3}), (2, 3, {"key": 1, "weight": 2})],
+ list(mst_edges),
+ )
+
def test_ignore_nan(self):
"""Tests that the edges with NaN weights are ignored or
raise an Error based on ignore_nan is true or false"""