diff options
Diffstat (limited to 'networkx/classes/backends.py')
-rw-r--r-- | networkx/classes/backends.py | 160 |
1 files changed, 103 insertions, 57 deletions
diff --git a/networkx/classes/backends.py b/networkx/classes/backends.py index c15edc80..30ddc5ec 100644 --- a/networkx/classes/backends.py +++ b/networkx/classes/backends.py @@ -43,20 +43,21 @@ the backend implementation. Example pytest invocation: NETWORKX_GRAPH_CONVERT=sparse pytest --pyargs networkx -Any dispatchable algorithms which are not implemented by the backend +Dispatchable algorithms which are not implemented by the backend will cause a `pytest.xfail()`, giving some indication that not all tests are working without causing an explicit failure. -A special `mark_nx_tests(items)` function may be defined by the backend. +A special `on_start_tests(items)` function may be defined by the backend. It will be called with the list of NetworkX tests discovered. Each item -is a pytest.Node object. If the backend does not support the test, it -can be marked as xfail to indicate it is not being handled. +is a pytest.Node object. If the backend does not support the test, that +test can be marked as xfail. """ import os import sys import inspect -from functools import wraps +import functools from importlib.metadata import entry_points +from ..exception import NetworkXNotImplemented __all__ = ["dispatch", "mark_tests"] @@ -98,27 +99,98 @@ class PluginInfo: plugins = PluginInfo() - - -def dispatch(algo): - def algorithm(func): - @wraps(func) - def wrapper(*args, **kwds): - graph = args[0] - if hasattr(graph, "__networkx_plugin__") and plugins: - plugin_name = graph.__networkx_plugin__ - if plugin_name in plugins: - backend = plugins[plugin_name].load() - if hasattr(backend, algo): - return getattr(backend, algo).__call__(*args, **kwds) - return func(*args, **kwds) - - return wrapper - - return algorithm - - -# Override `dispatch` for testing +_registered_algorithms = {} + + +def _register_algo(name, wrapped_func): + if name in _registered_algorithms: + raise KeyError(f"Algorithm already exists in dispatch registry: {name}") + _registered_algorithms[name] = wrapped_func + + +def dispatch(func=None, *, name=None): + """Dispatches to a backend algorithm + when the first argument is a backend graph-like object. + """ + # Allow any of the following decorator forms: + # - @dispatch + # - @dispatch() + # - @dispatch("override_name") + # - @dispatch(name="override_name") + if func is None: + if name is None: + return dispatch + return functools.partial(dispatch, name=name) + if isinstance(func, str): + return functools.partial(dispatch, name=func) + # If name not provided, use the name of the function + if name is None: + name = func.__name__ + + @functools.wraps(func) + def wrapper(*args, **kwds): + graph = args[0] + if hasattr(graph, "__networkx_plugin__") and plugins: + plugin_name = graph.__networkx_plugin__ + if plugin_name in plugins: + backend = plugins[plugin_name].load() + if hasattr(backend, name): + return getattr(backend, name).__call__(*args, **kwds) + else: + raise NetworkXNotImplemented(f"'{name}' not implemented by {plugin_name}") + return func(*args, **kwds) + + _register_algo(name, wrapper) + return wrapper + + +def test_override_dispatch(func=None, *, name=None): + """Auto-converts the first argument into the backend equivalent, + causing the dispatching mechanism to trigger for every + decorated algorithm.""" + if func is None: + if name is None: + return test_override_dispatch + return functools.partial(test_override_dispatch, name=name) + if isinstance(func, str): + return functools.partial(test_override_dispatch, name=func) + # If name not provided, use the name of the function + if name is None: + name = func.__name__ + + sig = inspect.signature(func) + + @functools.wraps(func) + def wrapper(*args, **kwds): + backend = plugins[plugin_name].load() + if not hasattr(backend, name): + pytest.xfail(f"'{name}' not implemented by {plugin_name}") + bound = sig.bind(*args, **kwds) + bound.apply_defaults() + graph, *args = args + # Convert graph into backend graph-like object + # Include the weight label, if provided to the algorithm + weight = None + if "weight" in bound.arguments: + weight = bound.arguments["weight"] + elif "data" in bound.arguments and "default" in bound.arguments: + # This case exists for several MultiGraph edge algorithms + if isinstance(bound.arguments["data"], str): + weight = bound.arguments["data"] + elif bound.arguments["data"]: + weight = "weight" + graph = backend.convert(graph, weight=weight) + return getattr(backend, name).__call__(graph, *args, **kwds) + + _register_algo(name, wrapper) + return wrapper + + +# Check for auto-convert testing +# This allows existing NetworkX tests to be run against a backend +# implementation without any changes to the testing code. The only +# required change is to set an environment variable prior to running +# pytest. if os.environ.get("NETWORKX_GRAPH_CONVERT"): plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"] if plugin_name not in known_plugins: @@ -133,40 +205,14 @@ if os.environ.get("NETWORKX_GRAPH_CONVERT"): except ImportError: raise ImportError(f"Missing pytest, which is required when using NETWORKX_GRAPH_CONVERT") - def dispatch(algo): - def algorithm(func): - sig = inspect.signature(func) - - @wraps(func) - def wrapper(*args, **kwds): - backend = plugins[plugin_name].load() - if not hasattr(backend, algo): - pytest.xfail(f"'{algo}' not implemented by {plugin_name}") - bound = sig.bind(*args, **kwds) - bound.apply_defaults() - graph, *args = args - # Convert graph into backend graph-like object - # Include the weight label, if provided to the algorithm - weight = None - if "weight" in bound.arguments: - weight = bound.arguments["weight"] - elif "data" in bound.arguments and "default" in bound.arguments: - # This case exists for several MultiGraph edge algorithms - if bound.arguments["data"]: - weight = "weight" - graph = backend.convert(graph, weight=weight) - return getattr(backend, algo).__call__(graph, *args, **kwds) - - return wrapper - - return algorithm + # Override `dispatch` for testing + dispatch = test_override_dispatch def mark_tests(items): - # Allow backend to mark tests (skip or xfail) if they aren't - # able to correctly handle them + """Allow backend to mark tests (skip or xfail) if they aren't able to correctly handle them""" if os.environ.get("NETWORKX_GRAPH_CONVERT"): plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"] backend = plugins[plugin_name].load() - if hasattr(backend, "mark_nx_tests"): - getattr(backend, "mark_nx_tests")(items) + if hasattr(backend, "on_start_tests"): + getattr(backend, "on_start_tests")(items) |