summaryrefslogtreecommitdiff
path: root/networkx/classes/backends.py
diff options
context:
space:
mode:
Diffstat (limited to 'networkx/classes/backends.py')
-rw-r--r--networkx/classes/backends.py160
1 files changed, 103 insertions, 57 deletions
diff --git a/networkx/classes/backends.py b/networkx/classes/backends.py
index c15edc80..30ddc5ec 100644
--- a/networkx/classes/backends.py
+++ b/networkx/classes/backends.py
@@ -43,20 +43,21 @@ the backend implementation.
Example pytest invocation:
NETWORKX_GRAPH_CONVERT=sparse pytest --pyargs networkx
-Any dispatchable algorithms which are not implemented by the backend
+Dispatchable algorithms which are not implemented by the backend
will cause a `pytest.xfail()`, giving some indication that not all
tests are working without causing an explicit failure.
-A special `mark_nx_tests(items)` function may be defined by the backend.
+A special `on_start_tests(items)` function may be defined by the backend.
It will be called with the list of NetworkX tests discovered. Each item
-is a pytest.Node object. If the backend does not support the test, it
-can be marked as xfail to indicate it is not being handled.
+is a pytest.Node object. If the backend does not support the test, that
+test can be marked as xfail.
"""
import os
import sys
import inspect
-from functools import wraps
+import functools
from importlib.metadata import entry_points
+from ..exception import NetworkXNotImplemented
__all__ = ["dispatch", "mark_tests"]
@@ -98,27 +99,98 @@ class PluginInfo:
plugins = PluginInfo()
-
-
-def dispatch(algo):
- def algorithm(func):
- @wraps(func)
- def wrapper(*args, **kwds):
- graph = args[0]
- if hasattr(graph, "__networkx_plugin__") and plugins:
- plugin_name = graph.__networkx_plugin__
- if plugin_name in plugins:
- backend = plugins[plugin_name].load()
- if hasattr(backend, algo):
- return getattr(backend, algo).__call__(*args, **kwds)
- return func(*args, **kwds)
-
- return wrapper
-
- return algorithm
-
-
-# Override `dispatch` for testing
+_registered_algorithms = {}
+
+
+def _register_algo(name, wrapped_func):
+ if name in _registered_algorithms:
+ raise KeyError(f"Algorithm already exists in dispatch registry: {name}")
+ _registered_algorithms[name] = wrapped_func
+
+
+def dispatch(func=None, *, name=None):
+ """Dispatches to a backend algorithm
+ when the first argument is a backend graph-like object.
+ """
+ # Allow any of the following decorator forms:
+ # - @dispatch
+ # - @dispatch()
+ # - @dispatch("override_name")
+ # - @dispatch(name="override_name")
+ if func is None:
+ if name is None:
+ return dispatch
+ return functools.partial(dispatch, name=name)
+ if isinstance(func, str):
+ return functools.partial(dispatch, name=func)
+ # If name not provided, use the name of the function
+ if name is None:
+ name = func.__name__
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwds):
+ graph = args[0]
+ if hasattr(graph, "__networkx_plugin__") and plugins:
+ plugin_name = graph.__networkx_plugin__
+ if plugin_name in plugins:
+ backend = plugins[plugin_name].load()
+ if hasattr(backend, name):
+ return getattr(backend, name).__call__(*args, **kwds)
+ else:
+ raise NetworkXNotImplemented(f"'{name}' not implemented by {plugin_name}")
+ return func(*args, **kwds)
+
+ _register_algo(name, wrapper)
+ return wrapper
+
+
+def test_override_dispatch(func=None, *, name=None):
+ """Auto-converts the first argument into the backend equivalent,
+ causing the dispatching mechanism to trigger for every
+ decorated algorithm."""
+ if func is None:
+ if name is None:
+ return test_override_dispatch
+ return functools.partial(test_override_dispatch, name=name)
+ if isinstance(func, str):
+ return functools.partial(test_override_dispatch, name=func)
+ # If name not provided, use the name of the function
+ if name is None:
+ name = func.__name__
+
+ sig = inspect.signature(func)
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwds):
+ backend = plugins[plugin_name].load()
+ if not hasattr(backend, name):
+ pytest.xfail(f"'{name}' not implemented by {plugin_name}")
+ bound = sig.bind(*args, **kwds)
+ bound.apply_defaults()
+ graph, *args = args
+ # Convert graph into backend graph-like object
+ # Include the weight label, if provided to the algorithm
+ weight = None
+ if "weight" in bound.arguments:
+ weight = bound.arguments["weight"]
+ elif "data" in bound.arguments and "default" in bound.arguments:
+ # This case exists for several MultiGraph edge algorithms
+ if isinstance(bound.arguments["data"], str):
+ weight = bound.arguments["data"]
+ elif bound.arguments["data"]:
+ weight = "weight"
+ graph = backend.convert(graph, weight=weight)
+ return getattr(backend, name).__call__(graph, *args, **kwds)
+
+ _register_algo(name, wrapper)
+ return wrapper
+
+
+# Check for auto-convert testing
+# This allows existing NetworkX tests to be run against a backend
+# implementation without any changes to the testing code. The only
+# required change is to set an environment variable prior to running
+# pytest.
if os.environ.get("NETWORKX_GRAPH_CONVERT"):
plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
if plugin_name not in known_plugins:
@@ -133,40 +205,14 @@ if os.environ.get("NETWORKX_GRAPH_CONVERT"):
except ImportError:
raise ImportError(f"Missing pytest, which is required when using NETWORKX_GRAPH_CONVERT")
- def dispatch(algo):
- def algorithm(func):
- sig = inspect.signature(func)
-
- @wraps(func)
- def wrapper(*args, **kwds):
- backend = plugins[plugin_name].load()
- if not hasattr(backend, algo):
- pytest.xfail(f"'{algo}' not implemented by {plugin_name}")
- bound = sig.bind(*args, **kwds)
- bound.apply_defaults()
- graph, *args = args
- # Convert graph into backend graph-like object
- # Include the weight label, if provided to the algorithm
- weight = None
- if "weight" in bound.arguments:
- weight = bound.arguments["weight"]
- elif "data" in bound.arguments and "default" in bound.arguments:
- # This case exists for several MultiGraph edge algorithms
- if bound.arguments["data"]:
- weight = "weight"
- graph = backend.convert(graph, weight=weight)
- return getattr(backend, algo).__call__(graph, *args, **kwds)
-
- return wrapper
-
- return algorithm
+ # Override `dispatch` for testing
+ dispatch = test_override_dispatch
def mark_tests(items):
- # Allow backend to mark tests (skip or xfail) if they aren't
- # able to correctly handle them
+ """Allow backend to mark tests (skip or xfail) if they aren't able to correctly handle them"""
if os.environ.get("NETWORKX_GRAPH_CONVERT"):
plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
backend = plugins[plugin_name].load()
- if hasattr(backend, "mark_nx_tests"):
- getattr(backend, "mark_nx_tests")(items)
+ if hasattr(backend, "on_start_tests"):
+ getattr(backend, "on_start_tests")(items)