summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMridul Seth <seth.mridul@gmail.com>2022-05-26 00:03:39 +0400
committerGitHub <noreply@github.com>2022-05-25 13:03:39 -0700
commit40339c8748f6737d30f29473c86c52800e2c6b87 (patch)
tree25591584d9aadfbb517a06b47622801364421265
parent58b63cb57cd1747c23611ee0b46991a5be2db751 (diff)
downloadnetworkx-40339c8748f6737d30f29473c86c52800e2c6b87.tar.gz
Extract valid kwds from the function signature for draw_networkx_* (#5660)
This makes sure we don't miss the cases where we add a new argument to a draw_networkx_* and not update the valid kwargs in draw_networkx.
-rw-r--r--networkx/drawing/nx_pylab.py67
1 files changed, 17 insertions, 50 deletions
diff --git a/networkx/drawing/nx_pylab.py b/networkx/drawing/nx_pylab.py
index 580d4d03..6c5cb804 100644
--- a/networkx/drawing/nx_pylab.py
+++ b/networkx/drawing/nx_pylab.py
@@ -268,58 +268,25 @@ def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
draw_networkx_labels
draw_networkx_edge_labels
"""
- import matplotlib.pyplot as plt
-
- valid_node_kwds = (
- "nodelist",
- "node_size",
- "node_color",
- "node_shape",
- "alpha",
- "cmap",
- "vmin",
- "vmax",
- "ax",
- "linewidths",
- "edgecolors",
- "label",
- )
+ from inspect import signature
- valid_edge_kwds = (
- "edgelist",
- "width",
- "edge_color",
- "style",
- "alpha",
- "arrowstyle",
- "arrowsize",
- "edge_cmap",
- "edge_vmin",
- "edge_vmax",
- "ax",
- "label",
- "node_size",
- "nodelist",
- "node_shape",
- "connectionstyle",
- "min_source_margin",
- "min_target_margin",
- )
-
- valid_label_kwds = (
- "labels",
- "font_size",
- "font_color",
- "font_family",
- "font_weight",
- "alpha",
- "bbox",
- "ax",
- "horizontalalignment",
- "verticalalignment",
- )
+ import matplotlib.pyplot as plt
- valid_kwds = valid_node_kwds + valid_edge_kwds + valid_label_kwds
+ # Get all valid keywords by inspecting the signatures of draw_networkx_nodes,
+ # draw_networkx_edges, draw_networkx_labels
+
+ valid_node_kwds = signature(draw_networkx_nodes).parameters.keys()
+ valid_edge_kwds = signature(draw_networkx_edges).parameters.keys()
+ valid_label_kwds = signature(draw_networkx_labels).parameters.keys()
+
+ # Create a set with all valid keywords across the three functions and
+ # remove the arguments of this function (draw_networkx)
+ valid_kwds = (valid_node_kwds | valid_edge_kwds | valid_label_kwds) - {
+ "G",
+ "pos",
+ "arrows",
+ "with_labels",
+ }
if any([k not in valid_kwds for k in kwds]):
invalid_args = ", ".join([k for k in kwds if k not in valid_kwds])