diff options
author | Alimi Qudirah <qudrohbidemi@gmail.com> | 2023-05-13 13:22:45 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-13 14:22:45 +0200 |
commit | 90bb1f0705f34c6570fe9fa537f672f9ddb3372c (patch) | |
tree | 4074bd018d91378bf54f27d6a8d9564f751d8553 | |
parent | c051696edfda4fb0cf27298259b723c57459ab60 (diff) | |
download | networkx-90bb1f0705f34c6570fe9fa537f672f9ddb3372c.tar.gz |
Improve test coverage for mst.py (#6540)
* fixes for 6539
* add more tests to test_mst.py
-rw-r--r-- | networkx/algorithms/tree/tests/test_mst.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/networkx/algorithms/tree/tests/test_mst.py b/networkx/algorithms/tree/tests/test_mst.py index ee0c3f60..373f16cf 100644 --- a/networkx/algorithms/tree/tests/test_mst.py +++ b/networkx/algorithms/tree/tests/test_mst.py @@ -9,6 +9,10 @@ from networkx.utils import edges_equal, nodes_equal def test_unknown_algorithm(): with pytest.raises(ValueError): nx.minimum_spanning_tree(nx.Graph(), algorithm="random") + with pytest.raises( + ValueError, match="random is not a valid choice for an algorithm." + ): + nx.maximum_spanning_edges(nx.Graph(), algorithm="random") class MinimumSpanningTreeTestBase: @@ -104,6 +108,19 @@ class MinimumSpanningTreeTestBase: with pytest.raises(ValueError): list(edges) + def test_nan_weights_MultiGraph(self): + G = nx.MultiGraph() + G.add_edge(0, 12, weight=float("nan")) + edges = nx.minimum_spanning_edges( + G, algorithm="prim", data=False, ignore_nan=False + ) + with pytest.raises(ValueError): + list(edges) + # test default for ignore_nan as False + edges = nx.minimum_spanning_edges(G, algorithm="prim", data=False) + with pytest.raises(ValueError): + list(edges) + def test_nan_weights_order(self): # now try again with a nan edge at the beginning of G.nodes edges = [ @@ -283,6 +300,14 @@ class TestKruskal(MultigraphMSTTestBase): ) assert edges_equal([(1, 2), (2, 3)], list(mst_edges)) + # both keys and data are included + mst_edges = nx.minimum_spanning_edges( + G, algorithm=self.algo, keys=True, data=True + ) + assert edges_equal( + [(1, 2, 1, {"weight": 2}), (2, 3, 1, {"weight": 2})], list(mst_edges) + ) + class TestPrim(MultigraphMSTTestBase): """Unit tests for computing a minimum (or maximum) spanning tree @@ -291,6 +316,18 @@ class TestPrim(MultigraphMSTTestBase): algorithm = "prim" + def test_prim_mst_edges_simple_graph(self): + H = nx.Graph() + H.add_edge(1, 2, key=2, weight=3) + H.add_edge(3, 2, key=1, weight=2) + H.add_edge(3, 1, key=1, weight=4) + + mst_edges = nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=True) + assert edges_equal( + [(1, 2, {"key": 2, "weight": 3}), (2, 3, {"key": 1, "weight": 2})], + list(mst_edges), + ) + def test_ignore_nan(self): """Tests that the edges with NaN weights are ignored or raise an Error based on ignore_nan is true or false""" |