summaryrefslogtreecommitdiff
path: root/networkx/classes/backends.py
blob: c15edc800dfa577b6eda1ebabffe3fd87ab6f285 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Code to support various backends in a plugin dispatch architecture.

Create a Dispatcher
-------------------

To be a valid plugin, a package must register an entry_point
of `networkx.plugins` with a known key pointing to the handler.

For example,
```
entry_points={'networkx.plugins': 'sparse = networkx_plugin_sparse'}
```

The plugin must create a Graph-like object which contains an attribute
`__networkx_plugin__` with a value of the known key.

Continuing the example above:
```
class WrappedSparse:
    __networkx_plugin__ = "sparse"
    ...
```

When a dispatchable networkx algorithm encounters a Graph-like object
with a `__networkx_plugin__` attribute, it will look for the associated
dispatch object in the entry_points, load it, and dispatch the work to it.


Testing
-------
To assist in validating the backend algorithm implementations, if an
environment variable `NETWORKX_GRAPH_CONVERT` is set to one of the known
plugin keys, the dispatch machinery will automatically convert regular
networkx Graphs and DiGraphs to the backend equivalent by calling
`<backend dispatcher>.convert(G, weight=weight)`.

By defining a `convert` method and setting the environment variable,
networkx will automatically route tests on dispatchable algorithms
to the backend, allowing the full networkx test suite to be run against
the backend implementation.

Example pytest invocation:
NETWORKX_GRAPH_CONVERT=sparse pytest --pyargs networkx

Any dispatchable algorithms which are not implemented by the backend
will cause a `pytest.xfail()`, giving some indication that not all
tests are working without causing an explicit failure.

A special `mark_nx_tests(items)` function may be defined by the backend.
It will be called with the list of NetworkX tests discovered. Each item
is a pytest.Node object. If the backend does not support the test, it
can be marked as xfail to indicate it is not being handled.
"""
import os
import sys
import inspect
from functools import wraps
from importlib.metadata import entry_points


__all__ = [
    "dispatch",
    "mark_tests",
]


# Plugin keys accepted by the NETWORKX_GRAPH_CONVERT testing override; any
# other value is rejected at import time. Backing libraries:
#   "sparse" -> scipy.sparse, "graphblas" -> python-graphblas, "cugraph" -> cugraph
known_plugins = ["sparse", "graphblas", "cugraph"]


class PluginInfo:
    """Lazily loaded ``networkx.plugins`` entry_point information.

    The entry points are only discovered the first time they are needed
    (``bool(...)``, ``in``, or ``[...]``), so importing this module does not
    scan installed packages.
    """

    # ``entry_points(group=...)`` together with ``EntryPoints.names`` and
    # ``EntryPoints[name]`` are only available on Python 3.10+. Older versions
    # return a dict mapping group name -> tuple of EntryPoint objects.
    _LEGACY_API = sys.version_info < (3, 10)

    def __init__(self):
        self._items = None

    def __bool__(self):
        """Return True if at least one ``networkx.plugins`` entry_point exists."""
        return len(self.items) > 0

    @property
    def items(self):
        """The registered ``networkx.plugins`` entry_points, loaded on first use."""
        if self._items is None:
            if self._LEGACY_API:
                # The group key is absent when no plugin is installed; treat
                # that as "no plugins" instead of raising KeyError.
                self._items = entry_points().get("networkx.plugins", ())
            else:
                self._items = entry_points(group="networkx.plugins")
        return self._items

    def __contains__(self, name):
        """Return True if an entry_point called ``name`` is registered."""
        if self._LEGACY_API:
            return any(ep.name == name for ep in self.items)
        return name in self.items.names

    def __getitem__(self, name):
        """Return the (first) entry_point called ``name``.

        Raises
        ------
        KeyError
            If no entry_point with that name is registered. This matches the
            behavior of ``EntryPoints.__getitem__`` on Python 3.10+.
        """
        if self._LEGACY_API:
            for ep in self.items:
                if ep.name == name:
                    return ep
            raise KeyError(name)
        return self.items[name]


# Module-wide registry; entry_points are not scanned until first accessed.
plugins = PluginInfo()


def dispatch(algo):
    """Decorator factory that lets a backend plugin take over an algorithm.

    The decorated function's first parameter is treated as the graph. If that
    graph has a ``__networkx_plugin__`` attribute naming a registered
    ``networkx.plugins`` entry_point, and the loaded backend has an attribute
    called ``algo``, the call is forwarded to the backend with the original
    arguments. Otherwise the NetworkX implementation runs.

    Parameters
    ----------
    algo : str
        Name of the attribute looked up on the backend.

    Returns
    -------
    callable
        A decorator that wraps the NetworkX implementation.
    """

    def algorithm(func):
        # Resolve the graph parameter's name once so that a graph passed by
        # keyword (e.g. ``func(G=G)``) is found instead of raising IndexError.
        graph_param = next(iter(inspect.signature(func).parameters), None)

        @wraps(func)
        def wrapper(*args, **kwds):
            graph = args[0] if args else kwds.get(graph_param)
            # Check the graph first: it is cheap and avoids scanning the
            # entry_points for ordinary NetworkX graphs.
            if hasattr(graph, "__networkx_plugin__") and plugins:
                plugin_name = graph.__networkx_plugin__
                if plugin_name in plugins:
                    backend = plugins[plugin_name].load()
                    if hasattr(backend, algo):
                        return getattr(backend, algo)(*args, **kwds)
            return func(*args, **kwds)

        return wrapper

    return algorithm


def _weight_from_arguments(arguments):
    """Return the edge-weight key to pass to a backend's ``convert``.

    Parameters
    ----------
    arguments : mapping
        Bound arguments of a dispatchable algorithm with defaults applied
        (``inspect.BoundArguments.arguments``).

    Returns
    -------
    The ``weight`` argument if the algorithm has one. Otherwise, for
    algorithms taking both ``data`` and ``default`` (several MultiGraph edge
    algorithms): ``data`` itself when it is a string attribute name,
    ``"weight"`` when ``data`` is otherwise truthy, and ``None`` in every
    other case.
    """
    if "weight" in arguments:
        return arguments["weight"]
    if "data" in arguments and "default" in arguments:
        data = arguments["data"]
        # ``data`` may name the edge attribute directly (e.g. data="cost");
        # mapping every truthy value to "weight" would convert the wrong key.
        if isinstance(data, str):
            return data
        if data:
            return "weight"
    return None


# Override `dispatch` for testing
if os.environ.get("NETWORKX_GRAPH_CONVERT"):
    plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
    if plugin_name not in known_plugins:
        raise Exception(f"{plugin_name} is not a known plugin; must be one of {known_plugins}")
    if not plugins:
        raise Exception("No registered networkx.plugins entry_points")
    if plugin_name not in plugins:
        raise Exception(f"No registered networkx.plugins entry_point named {plugin_name}")

    try:
        import pytest
    except ImportError as err:
        raise ImportError(
            "Missing pytest, which is required when using NETWORKX_GRAPH_CONVERT"
        ) from err

    def dispatch(algo):
        """Testing replacement for ``dispatch`` that always routes to the backend.

        Every call to a decorated algorithm converts the graph with
        ``backend.convert(graph, weight=...)`` and then calls the backend's
        ``algo``. Algorithms the backend does not provide are marked with
        ``pytest.xfail``.
        """

        def algorithm(func):
            sig = inspect.signature(func)
            # Name of the graph parameter, needed when the graph is passed by keyword.
            graph_param = next(iter(sig.parameters), None)

            @wraps(func)
            def wrapper(*args, **kwds):
                backend = plugins[plugin_name].load()
                if not hasattr(backend, algo):
                    pytest.xfail(f"'{algo}' not implemented by {plugin_name}")
                bound = sig.bind(*args, **kwds)
                bound.apply_defaults()
                if args:
                    graph, *args = args
                else:
                    # Graph given by keyword: drop it from ``kwds`` so the
                    # converted graph can be passed positionally below.
                    kwds.pop(graph_param, None)
                    graph = bound.arguments[graph_param]
                # Convert graph into backend graph-like object, including the
                # weight label if one was provided to the algorithm.
                weight = _weight_from_arguments(bound.arguments)
                graph = backend.convert(graph, weight=weight)
                return getattr(backend, algo)(graph, *args, **kwds)

            return wrapper

        return algorithm


def mark_tests(items):
    """Give the backend under test a chance to mark NetworkX tests.

    When ``NETWORKX_GRAPH_CONVERT`` names a plugin, that plugin's backend is
    loaded. If the backend defines ``mark_nx_tests``, it is called with
    ``items`` so it can mark tests it cannot handle (e.g. skip or xfail).
    Nothing happens when the environment variable is unset or empty.

    Parameters
    ----------
    items : list
        The collected test items.
    """
    plugin_name = os.environ.get("NETWORKX_GRAPH_CONVERT")
    if not plugin_name:
        return
    backend = plugins[plugin_name].load()
    if hasattr(backend, "mark_nx_tests"):
        backend.mark_nx_tests(items)