summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDimitrios Papageorgiou <dim_papag@windowslive.com>2021-10-28 03:47:41 +0300
committerGitHub <noreply@github.com>2021-10-27 17:47:41 -0700
commitc5e612f342660f1be7c09f31ce52535d37415296 (patch)
tree8f122dfb34498757d5ca1d1b92687283f2836cc5
parentcfb4b271166485fda8ebf82f00178f28602383bb (diff)
downloadnetworkx-c5e612f342660f1be7c09f31ce52535d37415296.tar.gz
Add option for arrowsize to be a list (#5154)
The arrowsize param to draw_networkx_edges can now be a list of values, one value per edge. Co-authored-by: Ross Barnowski <rossbar@berkeley.edu>
-rw-r--r--networkx/drawing/nx_pylab.py19
-rw-r--r--networkx/drawing/tests/test_pylab.py22
2 files changed, 37 insertions, 4 deletions
diff --git a/networkx/drawing/nx_pylab.py b/networkx/drawing/nx_pylab.py
index 836ff85e..16fc5933 100644
--- a/networkx/drawing/nx_pylab.py
+++ b/networkx/drawing/nx_pylab.py
@@ -155,10 +155,11 @@ def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
For directed graphs, choose the style of the arrowsheads.
See `matplotlib.patches.ArrowStyle` for more options.
- arrowsize : int (default=10)
+ arrowsize : int or list (default=10)
For directed graphs, choose the size of the arrow head's length and
- width. See `matplotlib.patches.FancyArrowPatch` for attribute
- `mutation_scale` for more info.
+ width. A list of values can be passed in to assign a different size for arrow head's length and width.
+ See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
+ for more info.
with_labels : bool (default=True)
Set to True to draw labels on the nodes.
@@ -750,7 +751,12 @@ def draw_networkx_edges(
# Draw arrows with `matplotlib.patches.FancyarrowPatch`
arrow_collection = []
- mutation_scale = arrowsize # scale factor of arrow head
+
+ if isinstance(arrowsize, list):
+ if len(arrowsize) != len(edge_pos):
+ raise ValueError("arrowsize should have the same length as edgelist")
+ else:
+ mutation_scale = arrowsize # scale factor of arrow head
base_connection_style = mpl.patches.ConnectionStyle(connectionstyle)
@@ -798,6 +804,11 @@ def draw_networkx_edges(
x2, y2 = dst
shrink_source = 0 # space from source to tail
shrink_target = 0 # space from head to target
+
+ if isinstance(arrowsize, list):
+ # Scale each factor of each arrow based on arrowsize list
+ mutation_scale = arrowsize[i]
+
if np.iterable(node_size): # many node sizes
source, target = edgelist[i][:2]
source_node_size = node_size[nodelist.index(source)]
diff --git a/networkx/drawing/tests/test_pylab.py b/networkx/drawing/tests/test_pylab.py
index f4a1a73e..4795613d 100644
--- a/networkx/drawing/tests/test_pylab.py
+++ b/networkx/drawing/tests/test_pylab.py
@@ -478,6 +478,28 @@ def test_error_invalid_kwds():
nx.draw(barbell, foo="bar")
+def test_draw_networkx_arrowsize_incorrect_size():
+ G = nx.DiGraph([(0, 1), (0, 2), (0, 3), (1, 3)])
+ arrowsize = [1, 2, 3]
+ with pytest.raises(
+ ValueError, match="arrowsize should have the same length as edgelist"
+ ):
+ nx.draw(G, arrowsize=arrowsize)
+
+
+@pytest.mark.parametrize("arrowsize", (30, [10, 20, 30]))
+def test_draw_edges_arrowsize(arrowsize):
+ G = nx.DiGraph([(0, 1), (0, 2), (1, 2)])
+ pos = {0: (0, 0), 1: (0, 1), 2: (1, 0)}
+ edges = nx.draw_networkx_edges(G, pos=pos, arrowsize=arrowsize)
+
+ arrowsize = itertools.repeat(arrowsize) if isinstance(arrowsize, int) else arrowsize
+
+ for fap, expected in zip(edges, arrowsize):
+ assert isinstance(fap, mpl.patches.FancyArrowPatch)
+ assert fap.get_mutation_scale() == expected
+
+
def test_np_edgelist():
# see issue #4129
np = pytest.importorskip("numpy")