summaryrefslogtreecommitdiff
path: root/networkx/testing/utils.py
blob: 68faf230560ebd9f4899b05b2c5b21c34022260f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Public API of this module: the names exported by a star-import.
__all__ = [
    "assert_nodes_equal",
    "assert_edges_equal",
    "assert_graphs_equal",
    "almost_equal",
]


def almost_equal(x, y, places=7):
    """Return True if ``x`` and ``y`` agree to ``places`` decimal places.

    Deprecated: emits a ``DeprecationWarning`` on every call; use
    ``pytest.approx`` instead.

    Parameters
    ----------
    x, y : numbers
        Values to compare; ``abs(x - y)`` must be defined.
    places : int, optional (default=7)
        Number of decimal places the rounded absolute difference is
        rounded to before being compared with zero.

    Returns
    -------
    bool
        ``round(abs(x - y), places) == 0``.
    """
    import warnings

    warnings.warn(
        (
            "`almost_equal` is deprecated and will be removed in version 3.0.\n"
            "Use `pytest.approx` instead.\n"
        ),
        DeprecationWarning,
        # Attribute the warning to the caller, not to this helper module, so
        # users can locate (and filter) the call sites they need to migrate.
        stacklevel=2,
    )
    return round(abs(x - y), places) == 0


def assert_nodes_equal(nodes1, nodes2):
    """Assert that two node iterables describe the same nodes.

    Each iterable may yield either bare nodes or ``(node, data)`` pairs.
    Both are first materialized as lists. If both lists can be fed to
    ``dict``, they are compared as ``node -> data`` mappings. Otherwise
    (``ValueError``/``TypeError``), only the node keys are compared. In both
    cases, ordering and duplicate entries are ignored.

    NOTE(review): any node that is itself a 2-element sequence (e.g. a
    ``(row, col)`` tuple or a two-character string) is unpacked by ``dict``
    as a ``(node, data)`` pair when every item in both inputs has length 2.

    Raises
    ------
    AssertionError
        If the two collections differ.
    """
    first, second = list(nodes1), list(nodes2)
    try:
        # Interpret items as (node, data) pairs.
        mapping1, mapping2 = dict(first), dict(second)
    except (ValueError, TypeError):
        # Items are plain nodes: compare the key sets only.
        mapping1, mapping2 = dict.fromkeys(first), dict.fromkeys(second)
    assert mapping1 == mapping2


def _edge_adjacency(edges):
    """Build an orientation-insensitive adjacency of edge payloads.

    Returns ``(adj, count)``. Here ``adj[u][v]`` (and the same list object at
    ``adj[v][u]``) is the list of ``e[2:]`` payloads for every edge ``e``
    between ``u`` and ``v``, and ``count`` is the number of edges consumed.
    """
    from collections import defaultdict

    adj = defaultdict(dict)
    count = 0
    for e in edges:
        count += 1
        u, v = e[0], e[1]
        # Payload is everything after the endpoints: (), (d,), or (k, d).
        data = adj[u].get(v, []) + [e[2:]]
        # Store one shared list under both orientations so that (u, v) and
        # (v, u) are treated as the same undirected edge.
        adj[u][v] = data
        adj[v][u] = data
    return adj, count


def assert_edges_equal(edges1, edges2):
    """Assert that two edge iterables describe the same multiset of edges.

    Edges may be given as ``(u, v)`` tuples, ``(u, v, d)`` tuples with a data
    dict, or ``(u, v, k, d)`` tuples with a key and data dict. Endpoint order
    is ignored, but edge multiplicity and the trailing payload (compared with
    ``==``) must match.

    Raises
    ------
    AssertionError
        If the edge counts differ, or if some edge (with its payload) occurs a
        different number of times in the two inputs.
    """
    d1, count1 = _edge_adjacency(edges1)
    d2, count2 = _edge_adjacency(edges2)
    # Compare true edge counts; the previous "last enumerate index" approach
    # made an empty input indistinguishable from a single-edge input.
    assert count1 == count2
    # Checking one direction suffices: each d1 payload multiplicity must
    # match d2 exactly, and the totals are equal, so d2 has no extra edges.
    for n, nbrdict in d1.items():
        for nbr, datalist in nbrdict.items():
            assert n in d2
            assert nbr in d2[n]
            d2datalist = d2[n][nbr]
            for data in datalist:
                assert datalist.count(data) == d2datalist.count(data)


def assert_graphs_equal(graph1, graph2):
    """Assert that two graphs have equal adjacency, node data and graph data.

    Compares, in order, the ``adj``, ``nodes`` and ``graph`` attributes of
    both graphs with ``==``, and stops at the first mismatch.

    Raises
    ------
    AssertionError
        If any of the three attributes differ.
    """
    for attribute in ("adj", "nodes", "graph"):
        assert getattr(graph1, attribute) == getattr(graph2, attribute)