summaryrefslogtreecommitdiff
path: root/networkx/drawing/nx_pylab.py
diff options
context:
space:
mode:
Diffstat (limited to 'networkx/drawing/nx_pylab.py')
-rw-r--r--networkx/drawing/nx_pylab.py67
1 files changed, 17 insertions, 50 deletions
diff --git a/networkx/drawing/nx_pylab.py b/networkx/drawing/nx_pylab.py
index 580d4d03..6c5cb804 100644
--- a/networkx/drawing/nx_pylab.py
+++ b/networkx/drawing/nx_pylab.py
@@ -268,58 +268,25 @@ def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
draw_networkx_labels
draw_networkx_edge_labels
"""
- import matplotlib.pyplot as plt
-
- valid_node_kwds = (
- "nodelist",
- "node_size",
- "node_color",
- "node_shape",
- "alpha",
- "cmap",
- "vmin",
- "vmax",
- "ax",
- "linewidths",
- "edgecolors",
- "label",
- )
+ from inspect import signature
- valid_edge_kwds = (
- "edgelist",
- "width",
- "edge_color",
- "style",
- "alpha",
- "arrowstyle",
- "arrowsize",
- "edge_cmap",
- "edge_vmin",
- "edge_vmax",
- "ax",
- "label",
- "node_size",
- "nodelist",
- "node_shape",
- "connectionstyle",
- "min_source_margin",
- "min_target_margin",
- )
-
- valid_label_kwds = (
- "labels",
- "font_size",
- "font_color",
- "font_family",
- "font_weight",
- "alpha",
- "bbox",
- "ax",
- "horizontalalignment",
- "verticalalignment",
- )
+ import matplotlib.pyplot as plt
- valid_kwds = valid_node_kwds + valid_edge_kwds + valid_label_kwds
+ # Get all valid keywords by inspecting the signatures of draw_networkx_nodes,
+ # draw_networkx_edges, draw_networkx_labels
+
+ valid_node_kwds = signature(draw_networkx_nodes).parameters.keys()
+ valid_edge_kwds = signature(draw_networkx_edges).parameters.keys()
+ valid_label_kwds = signature(draw_networkx_labels).parameters.keys()
+
+ # Create a set with all valid keywords across the three functions and
+ # remove the arguments of this function (draw_networkx)
+ valid_kwds = (valid_node_kwds | valid_edge_kwds | valid_label_kwds) - {
+ "G",
+ "pos",
+ "arrows",
+ "with_labels",
+ }
if any([k not in valid_kwds for k in kwds]):
invalid_args = ", ".join([k for k in kwds if k not in valid_kwds])