summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJim Kitchen <jim22k@gmail.com>2022-10-11 15:24:14 -0500
committerMridul Seth <git@mriduls.com>2022-10-12 12:05:14 +0400
commitc2111a36841d5e43b6beae1fd667996d35646e30 (patch)
tree38fa5e1316ca45a9906169ceb7ac21bfc52a612e
parent71434d674cf8ec6c3007dd41b78ee6f407e9b4eb (diff)
downloadnetworkx-nx-sparse.tar.gz
Allow dispatcher decorator without a namenx-sparse
- Name is taken from the decorated function - Raise error if backend doesn't implement a decorated function which is called - Check for duplicate names for dispatching algorithms
-rw-r--r--networkx/algorithms/boundary.py4
-rw-r--r--networkx/algorithms/cluster.py6
-rw-r--r--networkx/algorithms/core.py2
-rw-r--r--networkx/algorithms/cuts.py16
-rw-r--r--networkx/algorithms/dag.py4
-rw-r--r--networkx/algorithms/dominating.py2
-rw-r--r--networkx/algorithms/isolate.py6
-rw-r--r--networkx/algorithms/link_analysis/pagerank_alg.py2
-rw-r--r--networkx/algorithms/reciprocity.py4
-rw-r--r--networkx/algorithms/regular.py4
-rw-r--r--networkx/algorithms/simple_paths.py2
-rw-r--r--networkx/algorithms/smetric.py2
-rw-r--r--networkx/algorithms/structuralholes.py2
-rw-r--r--networkx/algorithms/tournament.py6
-rw-r--r--networkx/algorithms/triads.py2
-rw-r--r--networkx/classes/backends.py160
16 files changed, 135 insertions, 89 deletions
diff --git a/networkx/algorithms/boundary.py b/networkx/algorithms/boundary.py
index c74f47e9..068ce645 100644
--- a/networkx/algorithms/boundary.py
+++ b/networkx/algorithms/boundary.py
@@ -14,7 +14,7 @@ from itertools import chain
__all__ = ["edge_boundary", "node_boundary"]
-@nx.dispatch("edge_boundary")
+@nx.dispatch
def edge_boundary(G, nbunch1, nbunch2=None, data=False, keys=False, default=None):
"""Returns the edge boundary of `nbunch1`.
@@ -91,7 +91,7 @@ def edge_boundary(G, nbunch1, nbunch2=None, data=False, keys=False, default=None
)
-@nx.dispatch("node_boundary")
+@nx.dispatch()
def node_boundary(G, nbunch1, nbunch2=None):
"""Returns the node boundary of `nbunch1`.
diff --git a/networkx/algorithms/cluster.py b/networkx/algorithms/cluster.py
index 159aba78..83596c94 100644
--- a/networkx/algorithms/cluster.py
+++ b/networkx/algorithms/cluster.py
@@ -220,7 +220,7 @@ def _directed_weighted_triangles_and_degree_iter(G, nodes=None, weight="weight")
yield (i, dtotal, dbidirectional, directed_triangles)
-@nx.dispatch("average_clustering")
+@nx.dispatch(name="average_clustering")
def average_clustering(G, nodes=None, weight=None, count_zeros=True):
r"""Compute the average clustering coefficient for the graph G.
@@ -280,7 +280,7 @@ def average_clustering(G, nodes=None, weight=None, count_zeros=True):
return sum(c) / len(c)
-@nx.dispatch("clustering")
+@nx.dispatch(name="clustering")
def clustering(G, nodes=None, weight=None):
r"""Compute the clustering coefficient for nodes.
@@ -433,7 +433,7 @@ def transitivity(G):
return 0 if triangles == 0 else triangles / contri
-@nx.dispatch("square_clustering")
+@nx.dispatch(name="square_clustering")
def square_clustering(G, nodes=None):
r"""Compute the squares clustering coefficient for nodes.
diff --git a/networkx/algorithms/core.py b/networkx/algorithms/core.py
index ac339554..faf1f0d0 100644
--- a/networkx/algorithms/core.py
+++ b/networkx/algorithms/core.py
@@ -378,7 +378,7 @@ def k_corona(G, k, core_number=None):
return _core_subgraph(G, func, k, core_number)
-@nx.dispatch("k_truss")
+@nx.dispatch
@not_implemented_for("directed")
@not_implemented_for("multigraph")
def k_truss(G, k):
diff --git a/networkx/algorithms/cuts.py b/networkx/algorithms/cuts.py
index 271ff309..2b4794ef 100644
--- a/networkx/algorithms/cuts.py
+++ b/networkx/algorithms/cuts.py
@@ -21,7 +21,7 @@ __all__ = [
# TODO STILL NEED TO UPDATE ALL THE DOCUMENTATION!
-@nx.dispatch("cut_size")
+@nx.dispatch
def cut_size(G, S, T=None, weight=None):
"""Returns the size of the cut between two sets of nodes.
@@ -84,7 +84,7 @@ def cut_size(G, S, T=None, weight=None):
return sum(weight for u, v, weight in edges)
-@nx.dispatch("volume")
+@nx.dispatch
def volume(G, S, weight=None):
"""Returns the volume of a set of nodes.
@@ -127,7 +127,7 @@ def volume(G, S, weight=None):
return sum(d for v, d in degree(S, weight=weight))
-@nx.dispatch("normalized_cut_size")
+@nx.dispatch
def normalized_cut_size(G, S, T=None, weight=None):
"""Returns the normalized size of the cut between two sets of nodes.
@@ -180,7 +180,7 @@ def normalized_cut_size(G, S, T=None, weight=None):
return num_cut_edges * ((1 / volume_S) + (1 / volume_T))
-@nx.dispatch("conductance")
+@nx.dispatch
def conductance(G, S, T=None, weight=None):
"""Returns the conductance of two sets of nodes.
@@ -228,7 +228,7 @@ def conductance(G, S, T=None, weight=None):
return num_cut_edges / min(volume_S, volume_T)
-@nx.dispatch("edge_expansion")
+@nx.dispatch
def edge_expansion(G, S, T=None, weight=None):
"""Returns the edge expansion between two node sets.
@@ -275,7 +275,7 @@ def edge_expansion(G, S, T=None, weight=None):
return num_cut_edges / min(len(S), len(T))
-@nx.dispatch("mixing_expansion")
+@nx.dispatch
def mixing_expansion(G, S, T=None, weight=None):
"""Returns the mixing expansion between two node sets.
@@ -323,7 +323,7 @@ def mixing_expansion(G, S, T=None, weight=None):
# TODO What is the generalization to two arguments, S and T? Does the
# denominator become `min(len(S), len(T))`?
-@nx.dispatch("node_expansion")
+@nx.dispatch
def node_expansion(G, S):
"""Returns the node expansion of the set `S`.
@@ -363,7 +363,7 @@ def node_expansion(G, S):
# TODO What is the generalization to two arguments, S and T? Does the
# denominator become `min(len(S), len(T))`?
-@nx.dispatch("boundary_expansion")
+@nx.dispatch
def boundary_expansion(G, S):
"""Returns the boundary expansion of the set `S`.
diff --git a/networkx/algorithms/dag.py b/networkx/algorithms/dag.py
index f7692d1f..381f079c 100644
--- a/networkx/algorithms/dag.py
+++ b/networkx/algorithms/dag.py
@@ -36,7 +36,7 @@ __all__ = [
chaini = chain.from_iterable
-@nx.dispatch("descendants")
+@nx.dispatch
def descendants(G, source):
"""Returns all nodes reachable from `source` in `G`.
@@ -73,7 +73,7 @@ def descendants(G, source):
return {child for parent, child in nx.bfs_edges(G, source)}
-@nx.dispatch("ancestors")
+@nx.dispatch
def ancestors(G, source):
"""Returns all nodes having a path to `source` in `G`.
diff --git a/networkx/algorithms/dominating.py b/networkx/algorithms/dominating.py
index 3f2da907..f80454b7 100644
--- a/networkx/algorithms/dominating.py
+++ b/networkx/algorithms/dominating.py
@@ -64,7 +64,7 @@ def dominating_set(G, start_with=None):
return dominating_set
-@nx.dispatch("is_dominating_set")
+@nx.dispatch
def is_dominating_set(G, nbunch):
"""Checks if `nbunch` is a dominating set for `G`.
diff --git a/networkx/algorithms/isolate.py b/networkx/algorithms/isolate.py
index d59a46c4..0a077177 100644
--- a/networkx/algorithms/isolate.py
+++ b/networkx/algorithms/isolate.py
@@ -6,7 +6,7 @@ import networkx as nx
__all__ = ["is_isolate", "isolates", "number_of_isolates"]
-@nx.dispatch("is_isolate")
+@nx.dispatch
def is_isolate(G, n):
"""Determines whether a node is an isolate.
@@ -39,7 +39,7 @@ def is_isolate(G, n):
return G.degree(n) == 0
-@nx.dispatch("isolates")
+@nx.dispatch
def isolates(G):
"""Iterator over isolates in the graph.
@@ -85,7 +85,7 @@ def isolates(G):
return (n for n, d in G.degree() if d == 0)
-@nx.dispatch("number_of_isolates")
+@nx.dispatch
def number_of_isolates(G):
"""Returns the number of isolates in the graph.
diff --git a/networkx/algorithms/link_analysis/pagerank_alg.py b/networkx/algorithms/link_analysis/pagerank_alg.py
index 346fce03..03054322 100644
--- a/networkx/algorithms/link_analysis/pagerank_alg.py
+++ b/networkx/algorithms/link_analysis/pagerank_alg.py
@@ -6,7 +6,7 @@ import networkx as nx
__all__ = ["pagerank", "google_matrix"]
-@nx.dispatch("pagerank")
+@nx.dispatch
def pagerank(
G,
alpha=0.85,
diff --git a/networkx/algorithms/reciprocity.py b/networkx/algorithms/reciprocity.py
index 9819de8b..660d922f 100644
--- a/networkx/algorithms/reciprocity.py
+++ b/networkx/algorithms/reciprocity.py
@@ -7,7 +7,7 @@ from ..utils import not_implemented_for
__all__ = ["reciprocity", "overall_reciprocity"]
-@nx.dispatch("reciprocity")
+@nx.dispatch
@not_implemented_for("undirected", "multigraph")
def reciprocity(G, nodes=None):
r"""Compute the reciprocity in a directed graph.
@@ -75,7 +75,7 @@ def _reciprocity_iter(G, nodes):
yield (node, reciprocity)
-@nx.dispatch("overall_reciprocity")
+@nx.dispatch
@not_implemented_for("undirected", "multigraph")
def overall_reciprocity(G):
"""Compute the reciprocity for the whole graph.
diff --git a/networkx/algorithms/regular.py b/networkx/algorithms/regular.py
index 89b98742..7c0b2709 100644
--- a/networkx/algorithms/regular.py
+++ b/networkx/algorithms/regular.py
@@ -5,7 +5,7 @@ from networkx.utils import not_implemented_for
__all__ = ["is_regular", "is_k_regular", "k_factor"]
-@nx.dispatch("is_regular")
+@nx.dispatch
def is_regular(G):
"""Determines whether the graph ``G`` is a regular graph.
@@ -41,7 +41,7 @@ def is_regular(G):
return in_regular and out_regular
-@nx.dispatch("is_k_regular")
+@nx.dispatch
@not_implemented_for("directed")
def is_k_regular(G, k):
"""Determines whether the graph ``G`` is a k-regular graph.
diff --git a/networkx/algorithms/simple_paths.py b/networkx/algorithms/simple_paths.py
index 0ddacabb..6284f961 100644
--- a/networkx/algorithms/simple_paths.py
+++ b/networkx/algorithms/simple_paths.py
@@ -13,7 +13,7 @@ __all__ = [
]
-@nx.dispatch("is_simple_path")
+@nx.dispatch
def is_simple_path(G, nodes):
"""Returns True if and only if `nodes` form a simple path in `G`.
diff --git a/networkx/algorithms/smetric.py b/networkx/algorithms/smetric.py
index a97bb8d9..a440aff1 100644
--- a/networkx/algorithms/smetric.py
+++ b/networkx/algorithms/smetric.py
@@ -3,7 +3,7 @@ import networkx as nx
__all__ = ["s_metric"]
-@nx.dispatch("s_metric")
+@nx.dispatch
def s_metric(G, normalized=True):
"""Returns the s-metric of graph.
diff --git a/networkx/algorithms/structuralholes.py b/networkx/algorithms/structuralholes.py
index 9762587e..417ba028 100644
--- a/networkx/algorithms/structuralholes.py
+++ b/networkx/algorithms/structuralholes.py
@@ -5,7 +5,7 @@ import networkx as nx
__all__ = ["constraint", "local_constraint", "effective_size"]
-@nx.dispatch("mutual_weight")
+@nx.dispatch
def mutual_weight(G, u, v, weight=None):
"""Returns the sum of the weights of the edge from `u` to `v` and
the edge from `v` to `u` in `G`.
diff --git a/networkx/algorithms/tournament.py b/networkx/algorithms/tournament.py
index 60d09f31..a57c3592 100644
--- a/networkx/algorithms/tournament.py
+++ b/networkx/algorithms/tournament.py
@@ -61,7 +61,7 @@ def index_satisfying(iterable, condition):
raise ValueError("iterable must be non-empty") from err
-@nx.dispatch("is_tournament")
+@nx.dispatch
@not_implemented_for("undirected")
@not_implemented_for("multigraph")
def is_tournament(G):
@@ -180,7 +180,7 @@ def random_tournament(n, seed=None):
return nx.DiGraph(edges)
-@nx.dispatch("score_sequence")
+@nx.dispatch
@not_implemented_for("undirected")
@not_implemented_for("multigraph")
def score_sequence(G):
@@ -210,7 +210,7 @@ def score_sequence(G):
return sorted(d for v, d in G.out_degree())
-@nx.dispatch("tournament_matrix")
+@nx.dispatch
@not_implemented_for("undirected")
@not_implemented_for("multigraph")
def tournament_matrix(G):
diff --git a/networkx/algorithms/triads.py b/networkx/algorithms/triads.py
index 0042f5c3..316643ea 100644
--- a/networkx/algorithms/triads.py
+++ b/networkx/algorithms/triads.py
@@ -275,7 +275,7 @@ def triadic_census(G, nodelist=None):
return census
-@nx.dispatch("is_triad")
+@nx.dispatch()
def is_triad(G):
"""Returns True if the graph G is a triad, else False.
diff --git a/networkx/classes/backends.py b/networkx/classes/backends.py
index c15edc80..30ddc5ec 100644
--- a/networkx/classes/backends.py
+++ b/networkx/classes/backends.py
@@ -43,20 +43,21 @@ the backend implementation.
Example pytest invocation:
NETWORKX_GRAPH_CONVERT=sparse pytest --pyargs networkx
-Any dispatchable algorithms which are not implemented by the backend
+Dispatchable algorithms which are not implemented by the backend
will cause a `pytest.xfail()`, giving some indication that not all
tests are working without causing an explicit failure.
-A special `mark_nx_tests(items)` function may be defined by the backend.
+A special `on_start_tests(items)` function may be defined by the backend.
It will be called with the list of NetworkX tests discovered. Each item
-is a pytest.Node object. If the backend does not support the test, it
-can be marked as xfail to indicate it is not being handled.
+is a pytest.Node object. If the backend does not support the test, that
+test can be marked as xfail.
"""
import os
import sys
import inspect
-from functools import wraps
+import functools
from importlib.metadata import entry_points
+from ..exception import NetworkXNotImplemented
__all__ = ["dispatch", "mark_tests"]
@@ -98,27 +99,98 @@ class PluginInfo:
plugins = PluginInfo()
-
-
-def dispatch(algo):
- def algorithm(func):
- @wraps(func)
- def wrapper(*args, **kwds):
- graph = args[0]
- if hasattr(graph, "__networkx_plugin__") and plugins:
- plugin_name = graph.__networkx_plugin__
- if plugin_name in plugins:
- backend = plugins[plugin_name].load()
- if hasattr(backend, algo):
- return getattr(backend, algo).__call__(*args, **kwds)
- return func(*args, **kwds)
-
- return wrapper
-
- return algorithm
-
-
-# Override `dispatch` for testing
+_registered_algorithms = {}
+
+
+def _register_algo(name, wrapped_func):
+ if name in _registered_algorithms:
+ raise KeyError(f"Algorithm already exists in dispatch registry: {name}")
+ _registered_algorithms[name] = wrapped_func
+
+
+def dispatch(func=None, *, name=None):
+ """Dispatches to a backend algorithm
+ when the first argument is a backend graph-like object.
+ """
+ # Allow any of the following decorator forms:
+ # - @dispatch
+ # - @dispatch()
+ # - @dispatch("override_name")
+ # - @dispatch(name="override_name")
+ if func is None:
+ if name is None:
+ return dispatch
+ return functools.partial(dispatch, name=name)
+ if isinstance(func, str):
+ return functools.partial(dispatch, name=func)
+ # If name not provided, use the name of the function
+ if name is None:
+ name = func.__name__
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwds):
+ graph = args[0]
+ if hasattr(graph, "__networkx_plugin__") and plugins:
+ plugin_name = graph.__networkx_plugin__
+ if plugin_name in plugins:
+ backend = plugins[plugin_name].load()
+ if hasattr(backend, name):
+ return getattr(backend, name).__call__(*args, **kwds)
+ else:
+ raise NetworkXNotImplemented(f"'{name}' not implemented by {plugin_name}")
+ return func(*args, **kwds)
+
+ _register_algo(name, wrapper)
+ return wrapper
+
+
+def test_override_dispatch(func=None, *, name=None):
+ """Auto-converts the first argument into the backend equivalent,
+ causing the dispatching mechanism to trigger for every
+ decorated algorithm."""
+ if func is None:
+ if name is None:
+ return test_override_dispatch
+ return functools.partial(test_override_dispatch, name=name)
+ if isinstance(func, str):
+ return functools.partial(test_override_dispatch, name=func)
+ # If name not provided, use the name of the function
+ if name is None:
+ name = func.__name__
+
+ sig = inspect.signature(func)
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwds):
+ backend = plugins[plugin_name].load()
+ if not hasattr(backend, name):
+ pytest.xfail(f"'{name}' not implemented by {plugin_name}")
+ bound = sig.bind(*args, **kwds)
+ bound.apply_defaults()
+ graph, *args = args
+ # Convert graph into backend graph-like object
+ # Include the weight label, if provided to the algorithm
+ weight = None
+ if "weight" in bound.arguments:
+ weight = bound.arguments["weight"]
+ elif "data" in bound.arguments and "default" in bound.arguments:
+ # This case exists for several MultiGraph edge algorithms
+ if isinstance(bound.arguments["data"], str):
+ weight = bound.arguments["data"]
+ elif bound.arguments["data"]:
+ weight = "weight"
+ graph = backend.convert(graph, weight=weight)
+ return getattr(backend, name).__call__(graph, *args, **kwds)
+
+ _register_algo(name, wrapper)
+ return wrapper
+
+
+# Check for auto-convert testing
+# This allows existing NetworkX tests to be run against a backend
+# implementation without any changes to the testing code. The only
+# required change is to set an environment variable prior to running
+# pytest.
if os.environ.get("NETWORKX_GRAPH_CONVERT"):
plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
if plugin_name not in known_plugins:
@@ -133,40 +205,14 @@ if os.environ.get("NETWORKX_GRAPH_CONVERT"):
except ImportError:
raise ImportError(f"Missing pytest, which is required when using NETWORKX_GRAPH_CONVERT")
- def dispatch(algo):
- def algorithm(func):
- sig = inspect.signature(func)
-
- @wraps(func)
- def wrapper(*args, **kwds):
- backend = plugins[plugin_name].load()
- if not hasattr(backend, algo):
- pytest.xfail(f"'{algo}' not implemented by {plugin_name}")
- bound = sig.bind(*args, **kwds)
- bound.apply_defaults()
- graph, *args = args
- # Convert graph into backend graph-like object
- # Include the weight label, if provided to the algorithm
- weight = None
- if "weight" in bound.arguments:
- weight = bound.arguments["weight"]
- elif "data" in bound.arguments and "default" in bound.arguments:
- # This case exists for several MultiGraph edge algorithms
- if bound.arguments["data"]:
- weight = "weight"
- graph = backend.convert(graph, weight=weight)
- return getattr(backend, algo).__call__(graph, *args, **kwds)
-
- return wrapper
-
- return algorithm
+ # Override `dispatch` for testing
+ dispatch = test_override_dispatch
def mark_tests(items):
- # Allow backend to mark tests (skip or xfail) if they aren't
- # able to correctly handle them
+ """Allow backend to mark tests (skip or xfail) if they aren't able to correctly handle them"""
if os.environ.get("NETWORKX_GRAPH_CONVERT"):
plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
backend = plugins[plugin_name].load()
- if hasattr(backend, "mark_nx_tests"):
- getattr(backend, "mark_nx_tests")(items)
+ if hasattr(backend, "on_start_tests"):
+ getattr(backend, "on_start_tests")(items)