From c2111a36841d5e43b6beae1fd667996d35646e30 Mon Sep 17 00:00:00 2001 From: Jim Kitchen Date: Tue, 11 Oct 2022 15:24:14 -0500 Subject: Allow dispatcher decorator without a name - Name is taken from the decorated function - Raise error if backend doesn't implement a decorated function which is called - Check for duplicate names for dispatching algorithms --- networkx/algorithms/boundary.py | 4 +- networkx/algorithms/cluster.py | 6 +- networkx/algorithms/core.py | 2 +- networkx/algorithms/cuts.py | 16 +-- networkx/algorithms/dag.py | 4 +- networkx/algorithms/dominating.py | 2 +- networkx/algorithms/isolate.py | 6 +- networkx/algorithms/link_analysis/pagerank_alg.py | 2 +- networkx/algorithms/reciprocity.py | 4 +- networkx/algorithms/regular.py | 4 +- networkx/algorithms/simple_paths.py | 2 +- networkx/algorithms/smetric.py | 2 +- networkx/algorithms/structuralholes.py | 2 +- networkx/algorithms/tournament.py | 6 +- networkx/algorithms/triads.py | 2 +- networkx/classes/backends.py | 160 ++++++++++++++-------- 16 files changed, 135 insertions(+), 89 deletions(-) diff --git a/networkx/algorithms/boundary.py b/networkx/algorithms/boundary.py index c74f47e9..068ce645 100644 --- a/networkx/algorithms/boundary.py +++ b/networkx/algorithms/boundary.py @@ -14,7 +14,7 @@ from itertools import chain __all__ = ["edge_boundary", "node_boundary"] -@nx.dispatch("edge_boundary") +@nx.dispatch def edge_boundary(G, nbunch1, nbunch2=None, data=False, keys=False, default=None): """Returns the edge boundary of `nbunch1`. @@ -91,7 +91,7 @@ def edge_boundary(G, nbunch1, nbunch2=None, data=False, keys=False, default=None ) -@nx.dispatch("node_boundary") +@nx.dispatch() def node_boundary(G, nbunch1, nbunch2=None): """Returns the node boundary of `nbunch1`. diff --git a/networkx/algorithms/cluster.py b/networkx/algorithms/cluster.py index 159aba78..83596c94 100644 --- a/networkx/algorithms/cluster.py +++ b/networkx/algorithms/cluster.py @@ -220,7 +220,7 @@ def _directed_weighted_triangles_and_degree_iter(G, nodes=None, weight="weight") yield (i, dtotal, dbidirectional, directed_triangles) -@nx.dispatch("average_clustering") +@nx.dispatch(name="average_clustering") def average_clustering(G, nodes=None, weight=None, count_zeros=True): r"""Compute the average clustering coefficient for the graph G. @@ -280,7 +280,7 @@ def average_clustering(G, nodes=None, weight=None, count_zeros=True): return sum(c) / len(c) -@nx.dispatch("clustering") +@nx.dispatch(name="clustering") def clustering(G, nodes=None, weight=None): r"""Compute the clustering coefficient for nodes. @@ -433,7 +433,7 @@ def transitivity(G): return 0 if triangles == 0 else triangles / contri -@nx.dispatch("square_clustering") +@nx.dispatch(name="square_clustering") def square_clustering(G, nodes=None): r"""Compute the squares clustering coefficient for nodes. diff --git a/networkx/algorithms/core.py b/networkx/algorithms/core.py index ac339554..faf1f0d0 100644 --- a/networkx/algorithms/core.py +++ b/networkx/algorithms/core.py @@ -378,7 +378,7 @@ def k_corona(G, k, core_number=None): return _core_subgraph(G, func, k, core_number) -@nx.dispatch("k_truss") +@nx.dispatch @not_implemented_for("directed") @not_implemented_for("multigraph") def k_truss(G, k): diff --git a/networkx/algorithms/cuts.py b/networkx/algorithms/cuts.py index 271ff309..2b4794ef 100644 --- a/networkx/algorithms/cuts.py +++ b/networkx/algorithms/cuts.py @@ -21,7 +21,7 @@ __all__ = [ # TODO STILL NEED TO UPDATE ALL THE DOCUMENTATION! -@nx.dispatch("cut_size") +@nx.dispatch def cut_size(G, S, T=None, weight=None): """Returns the size of the cut between two sets of nodes. @@ -84,7 +84,7 @@ def cut_size(G, S, T=None, weight=None): return sum(weight for u, v, weight in edges) -@nx.dispatch("volume") +@nx.dispatch def volume(G, S, weight=None): """Returns the volume of a set of nodes. @@ -127,7 +127,7 @@ def volume(G, S, weight=None): return sum(d for v, d in degree(S, weight=weight)) -@nx.dispatch("normalized_cut_size") +@nx.dispatch def normalized_cut_size(G, S, T=None, weight=None): """Returns the normalized size of the cut between two sets of nodes. @@ -180,7 +180,7 @@ def normalized_cut_size(G, S, T=None, weight=None): return num_cut_edges * ((1 / volume_S) + (1 / volume_T)) -@nx.dispatch("conductance") +@nx.dispatch def conductance(G, S, T=None, weight=None): """Returns the conductance of two sets of nodes. @@ -228,7 +228,7 @@ def conductance(G, S, T=None, weight=None): return num_cut_edges / min(volume_S, volume_T) -@nx.dispatch("edge_expansion") +@nx.dispatch def edge_expansion(G, S, T=None, weight=None): """Returns the edge expansion between two node sets. @@ -275,7 +275,7 @@ def edge_expansion(G, S, T=None, weight=None): return num_cut_edges / min(len(S), len(T)) -@nx.dispatch("mixing_expansion") +@nx.dispatch def mixing_expansion(G, S, T=None, weight=None): """Returns the mixing expansion between two node sets. @@ -323,7 +323,7 @@ def mixing_expansion(G, S, T=None, weight=None): # TODO What is the generalization to two arguments, S and T? Does the # denominator become `min(len(S), len(T))`? -@nx.dispatch("node_expansion") +@nx.dispatch def node_expansion(G, S): """Returns the node expansion of the set `S`. @@ -363,7 +363,7 @@ def node_expansion(G, S): # TODO What is the generalization to two arguments, S and T? Does the # denominator become `min(len(S), len(T))`? -@nx.dispatch("boundary_expansion") +@nx.dispatch def boundary_expansion(G, S): """Returns the boundary expansion of the set `S`. diff --git a/networkx/algorithms/dag.py b/networkx/algorithms/dag.py index f7692d1f..381f079c 100644 --- a/networkx/algorithms/dag.py +++ b/networkx/algorithms/dag.py @@ -36,7 +36,7 @@ __all__ = [ chaini = chain.from_iterable -@nx.dispatch("descendants") +@nx.dispatch def descendants(G, source): """Returns all nodes reachable from `source` in `G`. @@ -73,7 +73,7 @@ def descendants(G, source): return {child for parent, child in nx.bfs_edges(G, source)} -@nx.dispatch("ancestors") +@nx.dispatch def ancestors(G, source): """Returns all nodes having a path to `source` in `G`. diff --git a/networkx/algorithms/dominating.py b/networkx/algorithms/dominating.py index 3f2da907..f80454b7 100644 --- a/networkx/algorithms/dominating.py +++ b/networkx/algorithms/dominating.py @@ -64,7 +64,7 @@ def dominating_set(G, start_with=None): return dominating_set -@nx.dispatch("is_dominating_set") +@nx.dispatch def is_dominating_set(G, nbunch): """Checks if `nbunch` is a dominating set for `G`. diff --git a/networkx/algorithms/isolate.py b/networkx/algorithms/isolate.py index d59a46c4..0a077177 100644 --- a/networkx/algorithms/isolate.py +++ b/networkx/algorithms/isolate.py @@ -6,7 +6,7 @@ import networkx as nx __all__ = ["is_isolate", "isolates", "number_of_isolates"] -@nx.dispatch("is_isolate") +@nx.dispatch def is_isolate(G, n): """Determines whether a node is an isolate. @@ -39,7 +39,7 @@ def is_isolate(G, n): return G.degree(n) == 0 -@nx.dispatch("isolates") +@nx.dispatch def isolates(G): """Iterator over isolates in the graph. @@ -85,7 +85,7 @@ def isolates(G): return (n for n, d in G.degree() if d == 0) -@nx.dispatch("number_of_isolates") +@nx.dispatch def number_of_isolates(G): """Returns the number of isolates in the graph. diff --git a/networkx/algorithms/link_analysis/pagerank_alg.py b/networkx/algorithms/link_analysis/pagerank_alg.py index 346fce03..03054322 100644 --- a/networkx/algorithms/link_analysis/pagerank_alg.py +++ b/networkx/algorithms/link_analysis/pagerank_alg.py @@ -6,7 +6,7 @@ import networkx as nx __all__ = ["pagerank", "google_matrix"] -@nx.dispatch("pagerank") +@nx.dispatch def pagerank( G, alpha=0.85, diff --git a/networkx/algorithms/reciprocity.py b/networkx/algorithms/reciprocity.py index 9819de8b..660d922f 100644 --- a/networkx/algorithms/reciprocity.py +++ b/networkx/algorithms/reciprocity.py @@ -7,7 +7,7 @@ from ..utils import not_implemented_for __all__ = ["reciprocity", "overall_reciprocity"] -@nx.dispatch("reciprocity") +@nx.dispatch @not_implemented_for("undirected", "multigraph") def reciprocity(G, nodes=None): r"""Compute the reciprocity in a directed graph. @@ -75,7 +75,7 @@ def _reciprocity_iter(G, nodes): yield (node, reciprocity) -@nx.dispatch("overall_reciprocity") +@nx.dispatch @not_implemented_for("undirected", "multigraph") def overall_reciprocity(G): """Compute the reciprocity for the whole graph. diff --git a/networkx/algorithms/regular.py b/networkx/algorithms/regular.py index 89b98742..7c0b2709 100644 --- a/networkx/algorithms/regular.py +++ b/networkx/algorithms/regular.py @@ -5,7 +5,7 @@ from networkx.utils import not_implemented_for __all__ = ["is_regular", "is_k_regular", "k_factor"] -@nx.dispatch("is_regular") +@nx.dispatch def is_regular(G): """Determines whether the graph ``G`` is a regular graph. @@ -41,7 +41,7 @@ def is_regular(G): return in_regular and out_regular -@nx.dispatch("is_k_regular") +@nx.dispatch @not_implemented_for("directed") def is_k_regular(G, k): """Determines whether the graph ``G`` is a k-regular graph. diff --git a/networkx/algorithms/simple_paths.py b/networkx/algorithms/simple_paths.py index 0ddacabb..6284f961 100644 --- a/networkx/algorithms/simple_paths.py +++ b/networkx/algorithms/simple_paths.py @@ -13,7 +13,7 @@ __all__ = [ ] -@nx.dispatch("is_simple_path") +@nx.dispatch def is_simple_path(G, nodes): """Returns True if and only if `nodes` form a simple path in `G`. diff --git a/networkx/algorithms/smetric.py b/networkx/algorithms/smetric.py index a97bb8d9..a440aff1 100644 --- a/networkx/algorithms/smetric.py +++ b/networkx/algorithms/smetric.py @@ -3,7 +3,7 @@ import networkx as nx __all__ = ["s_metric"] -@nx.dispatch("s_metric") +@nx.dispatch def s_metric(G, normalized=True): """Returns the s-metric of graph. diff --git a/networkx/algorithms/structuralholes.py b/networkx/algorithms/structuralholes.py index 9762587e..417ba028 100644 --- a/networkx/algorithms/structuralholes.py +++ b/networkx/algorithms/structuralholes.py @@ -5,7 +5,7 @@ import networkx as nx __all__ = ["constraint", "local_constraint", "effective_size"] -@nx.dispatch("mutual_weight") +@nx.dispatch def mutual_weight(G, u, v, weight=None): """Returns the sum of the weights of the edge from `u` to `v` and the edge from `v` to `u` in `G`. diff --git a/networkx/algorithms/tournament.py b/networkx/algorithms/tournament.py index 60d09f31..a57c3592 100644 --- a/networkx/algorithms/tournament.py +++ b/networkx/algorithms/tournament.py @@ -61,7 +61,7 @@ def index_satisfying(iterable, condition): raise ValueError("iterable must be non-empty") from err -@nx.dispatch("is_tournament") +@nx.dispatch @not_implemented_for("undirected") @not_implemented_for("multigraph") def is_tournament(G): @@ -180,7 +180,7 @@ def random_tournament(n, seed=None): return nx.DiGraph(edges) -@nx.dispatch("score_sequence") +@nx.dispatch @not_implemented_for("undirected") @not_implemented_for("multigraph") def score_sequence(G): @@ -210,7 +210,7 @@ def score_sequence(G): return sorted(d for v, d in G.out_degree()) -@nx.dispatch("tournament_matrix") +@nx.dispatch @not_implemented_for("undirected") @not_implemented_for("multigraph") def tournament_matrix(G): diff --git a/networkx/algorithms/triads.py b/networkx/algorithms/triads.py index 0042f5c3..316643ea 100644 --- a/networkx/algorithms/triads.py +++ b/networkx/algorithms/triads.py @@ -275,7 +275,7 @@ def triadic_census(G, nodelist=None): return census -@nx.dispatch("is_triad") +@nx.dispatch() def is_triad(G): """Returns True if the graph G is a triad, else False. diff --git a/networkx/classes/backends.py b/networkx/classes/backends.py index c15edc80..30ddc5ec 100644 --- a/networkx/classes/backends.py +++ b/networkx/classes/backends.py @@ -43,20 +43,21 @@ the backend implementation. Example pytest invocation: NETWORKX_GRAPH_CONVERT=sparse pytest --pyargs networkx -Any dispatchable algorithms which are not implemented by the backend +Dispatchable algorithms which are not implemented by the backend will cause a `pytest.xfail()`, giving some indication that not all tests are working without causing an explicit failure. -A special `mark_nx_tests(items)` function may be defined by the backend. +A special `on_start_tests(items)` function may be defined by the backend. It will be called with the list of NetworkX tests discovered. Each item -is a pytest.Node object. If the backend does not support the test, it -can be marked as xfail to indicate it is not being handled. +is a pytest.Node object. If the backend does not support the test, that +test can be marked as xfail. """ import os import sys import inspect -from functools import wraps +import functools from importlib.metadata import entry_points +from ..exception import NetworkXNotImplemented __all__ = ["dispatch", "mark_tests"] @@ -98,27 +99,98 @@ class PluginInfo: plugins = PluginInfo() - - -def dispatch(algo): - def algorithm(func): - @wraps(func) - def wrapper(*args, **kwds): - graph = args[0] - if hasattr(graph, "__networkx_plugin__") and plugins: - plugin_name = graph.__networkx_plugin__ - if plugin_name in plugins: - backend = plugins[plugin_name].load() - if hasattr(backend, algo): - return getattr(backend, algo).__call__(*args, **kwds) - return func(*args, **kwds) - - return wrapper - - return algorithm - - -# Override `dispatch` for testing +_registered_algorithms = {} + + +def _register_algo(name, wrapped_func): + if name in _registered_algorithms: + raise KeyError(f"Algorithm already exists in dispatch registry: {name}") + _registered_algorithms[name] = wrapped_func + + +def dispatch(func=None, *, name=None): + """Dispatches to a backend algorithm + when the first argument is a backend graph-like object. + """ + # Allow any of the following decorator forms: + # - @dispatch + # - @dispatch() + # - @dispatch("override_name") + # - @dispatch(name="override_name") + if func is None: + if name is None: + return dispatch + return functools.partial(dispatch, name=name) + if isinstance(func, str): + return functools.partial(dispatch, name=func) + # If name not provided, use the name of the function + if name is None: + name = func.__name__ + + @functools.wraps(func) + def wrapper(*args, **kwds): + graph = args[0] + if hasattr(graph, "__networkx_plugin__") and plugins: + plugin_name = graph.__networkx_plugin__ + if plugin_name in plugins: + backend = plugins[plugin_name].load() + if hasattr(backend, name): + return getattr(backend, name).__call__(*args, **kwds) + else: + raise NetworkXNotImplemented(f"'{name}' not implemented by {plugin_name}") + return func(*args, **kwds) + + _register_algo(name, wrapper) + return wrapper + + +def test_override_dispatch(func=None, *, name=None): + """Auto-converts the first argument into the backend equivalent, + causing the dispatching mechanism to trigger for every + decorated algorithm.""" + if func is None: + if name is None: + return test_override_dispatch + return functools.partial(test_override_dispatch, name=name) + if isinstance(func, str): + return functools.partial(test_override_dispatch, name=func) + # If name not provided, use the name of the function + if name is None: + name = func.__name__ + + sig = inspect.signature(func) + + @functools.wraps(func) + def wrapper(*args, **kwds): + backend = plugins[plugin_name].load() + if not hasattr(backend, name): + pytest.xfail(f"'{name}' not implemented by {plugin_name}") + bound = sig.bind(*args, **kwds) + bound.apply_defaults() + graph, *args = args + # Convert graph into backend graph-like object + # Include the weight label, if provided to the algorithm + weight = None + if "weight" in bound.arguments: + weight = bound.arguments["weight"] + elif "data" in bound.arguments and "default" in bound.arguments: + # This case exists for several MultiGraph edge algorithms + if isinstance(bound.arguments["data"], str): + weight = bound.arguments["data"] + elif bound.arguments["data"]: + weight = "weight" + graph = backend.convert(graph, weight=weight) + return getattr(backend, name).__call__(graph, *args, **kwds) + + _register_algo(name, wrapper) + return wrapper + + +# Check for auto-convert testing +# This allows existing NetworkX tests to be run against a backend +# implementation without any changes to the testing code. The only +# required change is to set an environment variable prior to running +# pytest. if os.environ.get("NETWORKX_GRAPH_CONVERT"): plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"] if plugin_name not in known_plugins: @@ -133,40 +205,14 @@ if os.environ.get("NETWORKX_GRAPH_CONVERT"): except ImportError: raise ImportError(f"Missing pytest, which is required when using NETWORKX_GRAPH_CONVERT") - def dispatch(algo): - def algorithm(func): - sig = inspect.signature(func) - - @wraps(func) - def wrapper(*args, **kwds): - backend = plugins[plugin_name].load() - if not hasattr(backend, algo): - pytest.xfail(f"'{algo}' not implemented by {plugin_name}") - bound = sig.bind(*args, **kwds) - bound.apply_defaults() - graph, *args = args - # Convert graph into backend graph-like object - # Include the weight label, if provided to the algorithm - weight = None - if "weight" in bound.arguments: - weight = bound.arguments["weight"] - elif "data" in bound.arguments and "default" in bound.arguments: - # This case exists for several MultiGraph edge algorithms - if bound.arguments["data"]: - weight = "weight" - graph = backend.convert(graph, weight=weight) - return getattr(backend, algo).__call__(graph, *args, **kwds) - - return wrapper - - return algorithm + # Override `dispatch` for testing + dispatch = test_override_dispatch def mark_tests(items): - # Allow backend to mark tests (skip or xfail) if they aren't - # able to correctly handle them + """Allow backend to mark tests (skip or xfail) if they aren't able to correctly handle them""" if os.environ.get("NETWORKX_GRAPH_CONVERT"): plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"] backend = plugins[plugin_name].load() - if hasattr(backend, "mark_nx_tests"): - getattr(backend, "mark_nx_tests")(items) + if hasattr(backend, "on_start_tests"): + getattr(backend, "on_start_tests")(items) -- cgit v1.2.1