summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJim Kitchen <jim22k@users.noreply.github.com>2023-03-28 18:17:46 -0500
committerGitHub <noreply@github.com>2023-03-28 19:17:46 -0400
commitabd0a82d06c0ae8ecc1c60e5a9bde4c0dee08d48 (patch)
treef84c8b95d9fbbaaf3bedacbc4dacefc0587e192c
parent3aad1e8974e5e457adbe37fc55ac9755b5f7a388 (diff)
downloadnetworkx-abd0a82d06c0ae8ecc1c60e5a9bde4c0dee08d48.tar.gz
Test dispatching via nx-loopback backend (#6536)
* Add tests for nx._dispatch decorator The dispatch functionality is used to delegate graph computations to a different backend. Because those backends are not part of NetworkX, testing the dispatching feature was not originally added, relying instead on the other backends (e.g. graphblas-algorithms) to verify the dispatch functionality is working. This change creates a "loopback" backend where NetworkX dispatches to itself for the sole purpose of exercising the dispatching machinery. In one incarnation, various tests are augmented to use the LoopbackGraph family and force loopback dispatching to occur as normal usage would. A second incarnation forces *all* tests to run in dispatch mode but use of a different _dispatch decorator. This mode is triggered for all of pytest, so it must be tested by the CI system specifically. * Update CI to hopefully run dispatching auto tests * Formatting * More formatting fixes * Better comments explaining dispatching tests
-rw-r--r--.github/workflows/test.yml6
-rw-r--r--networkx/algorithms/components/tests/test_connected.py8
-rw-r--r--networkx/algorithms/link_analysis/tests/test_pagerank.py8
-rw-r--r--networkx/algorithms/tests/test_structuralholes.py8
-rw-r--r--networkx/classes/backends.py13
-rw-r--r--networkx/classes/tests/dispatch_interface.py83
-rw-r--r--setup.py6
7 files changed, 124 insertions, 8 deletions
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 20f9ecfe..cc82b101 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -55,6 +55,12 @@ jobs:
run: |
pytest --doctest-modules --durations=10 --pyargs networkx
+ - name: Test Dispatching
+ # Limit this to only a single combination from the matrix
+ if: ${{ (matrix.os == 'ubuntu') && (matrix.python-version == '3.11') }}
+ run: |
+ NETWORKX_GRAPH_CONVERT=nx-loopback pytest --doctest-modules --durations=10 --pyargs networkx
+
extra:
runs-on: ${{ matrix.os }}
strategy:
diff --git a/networkx/algorithms/components/tests/test_connected.py b/networkx/algorithms/components/tests/test_connected.py
index bf3954e9..4c9b8d28 100644
--- a/networkx/algorithms/components/tests/test_connected.py
+++ b/networkx/algorithms/components/tests/test_connected.py
@@ -3,6 +3,7 @@ import pytest
import networkx as nx
from networkx import NetworkXNotImplemented
from networkx import convert_node_labels_to_integers as cnlti
+from networkx.classes.tests import dispatch_interface
class TestConnected:
@@ -60,9 +61,12 @@ class TestConnected:
C = []
cls.gc.append((G, C))
- def test_connected_components(self):
+ # This additionally tests the @nx._dispatch mechanism, treating
+ # nx.connected_components as if it were a re-implementation from another package
+ @pytest.mark.parametrize("wrapper", [lambda x: x, dispatch_interface.convert])
+ def test_connected_components(self, wrapper):
cc = nx.connected_components
- G = self.G
+ G = wrapper(self.G)
C = {
frozenset([0, 1, 2, 3]),
frozenset([4, 5, 6, 7, 8, 9]),
diff --git a/networkx/algorithms/link_analysis/tests/test_pagerank.py b/networkx/algorithms/link_analysis/tests/test_pagerank.py
index fa73493b..6a30f0cd 100644
--- a/networkx/algorithms/link_analysis/tests/test_pagerank.py
+++ b/networkx/algorithms/link_analysis/tests/test_pagerank.py
@@ -3,6 +3,7 @@ import random
import pytest
import networkx as nx
+from networkx.classes.tests import dispatch_interface
np = pytest.importorskip("numpy")
pytest.importorskip("scipy")
@@ -82,8 +83,11 @@ class TestPageRank:
for n in G:
assert p[n] == pytest.approx(G.pagerank[n], abs=1e-4)
- def test_google_matrix(self):
- G = self.G
+ # This additionally tests the @nx._dispatch mechanism, treating
+ # nx.google_matrix as if it were a re-implementation from another package
+ @pytest.mark.parametrize("wrapper", [lambda x: x, dispatch_interface.convert])
+ def test_google_matrix(self, wrapper):
+ G = wrapper(self.G)
M = nx.google_matrix(G, alpha=0.9, nodelist=sorted(G))
_, ev = np.linalg.eig(M.T)
p = ev[:, 0] / ev[:, 0].sum()
diff --git a/networkx/algorithms/tests/test_structuralholes.py b/networkx/algorithms/tests/test_structuralholes.py
index 51447fee..6f92baa4 100644
--- a/networkx/algorithms/tests/test_structuralholes.py
+++ b/networkx/algorithms/tests/test_structuralholes.py
@@ -4,6 +4,7 @@ import math
import pytest
import networkx as nx
+from networkx.classes.tests import dispatch_interface
class TestStructuralHoles:
@@ -51,8 +52,11 @@ class TestStructuralHoles:
("G", "C"): 10,
}
- def test_constraint_directed(self):
- constraint = nx.constraint(self.D)
+ # This additionally tests the @nx._dispatch mechanism, treating
+ # nx.mutual_weight as if it were a re-implementation from another package
+ @pytest.mark.parametrize("wrapper", [lambda x: x, dispatch_interface.convert])
+ def test_constraint_directed(self, wrapper):
+ constraint = nx.constraint(wrapper(self.D))
assert constraint[0] == pytest.approx(1.003, abs=1e-3)
assert constraint[1] == pytest.approx(1.003, abs=1e-3)
assert constraint[2] == pytest.approx(1.389, abs=1e-3)
diff --git a/networkx/classes/backends.py b/networkx/classes/backends.py
index 89a200d4..761d0a4b 100644
--- a/networkx/classes/backends.py
+++ b/networkx/classes/backends.py
@@ -53,8 +53,8 @@ tests are working, while avoiding causing an explicit failure.
A special ``on_start_tests(items)`` function may be defined by the backend.
It will be called with the list of NetworkX tests discovered. Each item
-is a `pytest.Node` object. If the backend does not support the test, that
-test can be marked as xfail.
+is a test object that can be marked as xfail if the backend does not support
+the test using `item.add_marker(pytest.mark.xfail(reason=...))`.
"""
import functools
import inspect
@@ -147,6 +147,10 @@ def _dispatch(func=None, *, name=None):
)
return func(*args, **kwds)
+ # Keep a handle to the original function to use when testing
+ # the dispatch mechanism internally
+ wrapper._orig_func = func
+
_register_algo(name, wrapper)
return wrapper
@@ -171,6 +175,10 @@ def test_override_dispatch(func=None, *, name=None):
def wrapper(*args, **kwds):
backend = plugins[plugin_name].load()
if not hasattr(backend, name):
+ if plugin_name == "nx-loopback":
+ raise NetworkXNotImplemented(
+ f"'{name}' not found in {backend.__class__.__name__}"
+ )
pytest.xfail(f"'{name}' not implemented by {plugin_name}")
bound = sig.bind(*args, **kwds)
bound.apply_defaults()
@@ -196,6 +204,7 @@ def test_override_dispatch(func=None, *, name=None):
result = getattr(backend, name).__call__(graph, *args, **kwds)
return backend.convert_to_nx(result, name=name)
+ wrapper._orig_func = func
_register_algo(name, wrapper)
return wrapper
diff --git a/networkx/classes/tests/dispatch_interface.py b/networkx/classes/tests/dispatch_interface.py
new file mode 100644
index 00000000..ded79b36
--- /dev/null
+++ b/networkx/classes/tests/dispatch_interface.py
@@ -0,0 +1,83 @@
+# This file contains utilities for testing the dispatching feature
+
+# A full test of all dispatchable algorithms is performed by
+# modifying the pytest invocation and setting an environment variable
+# NETWORKX_GRAPH_CONVERT=nx-loopback pytest
+# This is comprehensive, but only tests the `test_override_dispatch`
+# function in networkx.classes.backends.
+
+# To test the `_dispatch` function directly, several tests scattered throughout
+# NetworkX have been augmented to test normal and dispatch mode.
+# Searching for `dispatch_interface` should locate the specific tests.
+
+import networkx as nx
+from networkx import DiGraph, Graph, MultiDiGraph, MultiGraph, PlanarEmbedding
+
+
+class LoopbackGraph(Graph):
+ __networkx_plugin__ = "nx-loopback"
+
+
+class LoopbackDiGraph(DiGraph):
+ __networkx_plugin__ = "nx-loopback"
+
+
+class LoopbackMultiGraph(MultiGraph):
+ __networkx_plugin__ = "nx-loopback"
+
+
+class LoopbackMultiDiGraph(MultiDiGraph):
+ __networkx_plugin__ = "nx-loopback"
+
+
+class LoopbackPlanarEmbedding(PlanarEmbedding):
+ __networkx_plugin__ = "nx-loopback"
+
+
+def convert(graph):
+ if isinstance(graph, PlanarEmbedding):
+ return LoopbackPlanarEmbedding(graph)
+ if isinstance(graph, MultiDiGraph):
+ return LoopbackMultiDiGraph(graph)
+ if isinstance(graph, MultiGraph):
+ return LoopbackMultiGraph(graph)
+ if isinstance(graph, DiGraph):
+ return LoopbackDiGraph(graph)
+ if isinstance(graph, Graph):
+ return LoopbackGraph(graph)
+ raise TypeError(f"Unsupported type of graph: {type(graph)}")
+
+
+class LoopbackDispatcher:
+ non_toplevel = {
+ "inter_community_edges": nx.community.quality.inter_community_edges,
+ "is_tournament": nx.algorithms.tournament.is_tournament,
+ "mutual_weight": nx.algorithms.structuralholes.mutual_weight,
+ "score_sequence": nx.algorithms.tournament.score_sequence,
+ "tournament_matrix": nx.algorithms.tournament.tournament_matrix,
+ }
+
+ def __getattr__(self, item):
+ # Return the original, undecorated NetworkX algorithm
+ if hasattr(nx, item):
+ return getattr(nx, item)._orig_func
+ if item in self.non_toplevel:
+ return self.non_toplevel[item]._orig_func
+ raise AttributeError(item)
+
+ @staticmethod
+ def convert_from_nx(graph, weight=None, *, name=None):
+ return graph
+
+ @staticmethod
+ def convert_to_nx(obj, *, name=None):
+ return obj
+
+ @staticmethod
+ def on_start_tests(items):
+ # Verify that items can be xfailed
+ for item in items:
+ assert hasattr(item, "add_marker")
+
+
+dispatcher = LoopbackDispatcher()
diff --git a/setup.py b/setup.py
index c40405a5..30a75746 100644
--- a/setup.py
+++ b/setup.py
@@ -151,6 +151,11 @@ package_data = {
"networkx.utils": ["tests/*.py"],
}
+# Loopback dispatcher required for testing nx._dispatch decorator
+entry_points = {
+ "networkx.plugins": "nx-loopback = networkx.classes.tests.dispatch_interface:dispatcher"
+}
+
def parse_requirements_file(filename):
with open(filename) as fid:
@@ -188,6 +193,7 @@ if __name__ == "__main__":
package_data=package_data,
install_requires=install_requires,
extras_require=extras_require,
+ entry_points=entry_points,
python_requires=">=3.8",
zip_safe=False,
)