diff options
-rw-r--r-- | networkx/drawing/nx_pylab.py | 13 | ||||
-rw-r--r-- | networkx/drawing/tests/test_pylab.py | 29 |
2 files changed, 37 insertions, 5 deletions
diff --git a/networkx/drawing/nx_pylab.py b/networkx/drawing/nx_pylab.py index 0e392e2b..56e89a71 100644 --- a/networkx/drawing/nx_pylab.py +++ b/networkx/drawing/nx_pylab.py @@ -697,6 +697,7 @@ def draw_networkx_edges( # FancyArrowPatch handles color=None different from LineCollection if edge_color is None: edge_color = "k" + edgelist_tuple = list(map(tuple, edgelist)) # set edge positions edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist]) @@ -797,7 +798,7 @@ def draw_networkx_edges( # FancyArrowPatch doesn't handle color strings arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha) - for i, (src, dst) in enumerate(edge_pos): + for i, (src, dst) in zip(fancy_edges_indices, edge_pos): x1, y1 = src x2, y2 = dst shrink_source = 0 # space from source to tail @@ -822,7 +823,7 @@ def draw_networkx_edges( if shrink_target < min_target_margin: shrink_target = min_target_margin - if len(arrow_colors) == len(edge_pos): + if len(arrow_colors) > i: arrow_color = arrow_colors[i] elif len(arrow_colors) == 1: arrow_color = arrow_colors[0] @@ -830,7 +831,7 @@ def draw_networkx_edges( arrow_color = arrow_colors[i % len(arrow_colors)] if np.iterable(width): - if len(width) == len(edge_pos): + if len(width) > i: line_width = width[i] else: line_width = width[i % len(width)] @@ -842,7 +843,7 @@ def draw_networkx_edges( and not isinstance(style, str) and not isinstance(style, tuple) ): - if len(style) == len(edge_pos): + if len(style) > i: linestyle = style[i] else: # Cycle through styles linestyle = style[i % len(style)] @@ -882,10 +883,14 @@ def draw_networkx_edges( # Make sure selfloop edges are also drawn selfloops_to_draw = [loop for loop in nx.selfloop_edges(G) if loop in edgelist] if selfloops_to_draw: + fancy_edges_indices = [ + edgelist_tuple.index(loop) for loop in selfloops_to_draw + ] edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in selfloops_to_draw]) arrowstyle = "-" _draw_networkx_edges_fancy_arrow_patch() else: + fancy_edges_indices = range(len(edgelist)) edge_viz_obj = _draw_networkx_edges_fancy_arrow_patch() # update view after drawing diff --git a/networkx/drawing/tests/test_pylab.py b/networkx/drawing/tests/test_pylab.py index 54ea6e0c..c5189fd6 100644 --- a/networkx/drawing/tests/test_pylab.py +++ b/networkx/drawing/tests/test_pylab.py @@ -5,6 +5,7 @@ import itertools import pytest mpl = pytest.importorskip("matplotlib") +np = pytest.importorskip("numpy") mpl.use("PS") plt = pytest.importorskip("matplotlib.pyplot") plt.rcParams["text.usetex"] = False @@ -528,7 +529,6 @@ def test_draw_edges_arrowsize(arrowsize): def test_np_edgelist(): # see issue #4129 - np = pytest.importorskip("numpy") nx.draw_networkx(barbell, edgelist=np.array([(0, 2), (0, 3)])) @@ -724,3 +724,30 @@ def test_draw_networkx_edge_label_empty_dict(): G = nx.path_graph(3) pos = {n: (n, n) for n in G.nodes} assert nx.draw_networkx_edge_labels(G, pos, edge_labels={}) == {} + + +def test_draw_networkx_edges_undirected_selfloop_colors(): + """When an edgelist is supplied along with a sequence of colors, check that + the self-loops have the correct colors.""" + fig, ax = plt.subplots() + # Edge list and corresponding colors + edgelist = [(1, 3), (1, 2), (2, 3), (1, 1), (3, 3), (2, 2)] + edge_colors = ["pink", "cyan", "black", "red", "blue", "green"] + + G = nx.Graph(edgelist) + pos = {n: (n, n) for n in G.nodes} + nx.draw_networkx_edges(G, pos, ax=ax, edgelist=edgelist, edge_color=edge_colors) + + # Verify that there are three fancy arrow patches (1 per self loop) + assert len(ax.patches) == 3 + + # These are points that should be contained in the self loops. For example, + # sl_points[0] will be (1, 1.1), which is inside the "path" of the first + # self-loop but outside the others + sl_points = np.array(edgelist[-3:]) + np.array([0, 0.1]) + + # Check that the mapping between self-loop locations and their colors is + # correct + for fap, clr, slp in zip(ax.patches, edge_colors[-3:], sl_points): + assert fap.get_path().contains_point(slp) + assert mpl.colors.same_color(fap.get_edgecolor(), clr) |