summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDan Schult <dschult@colgate.edu>2022-08-01 07:44:15 -0400
committerJarrod Millman <jarrod.millman@gmail.com>2022-08-21 09:29:54 -0700
commit0793d53d86cd45eab8ba188bcc7f11901eb56c17 (patch)
treeb25ad65af1e779d6d4adcc49985d66fb282df504
parent8a0a802570d25b73cee947507e0d04798515e9e0 (diff)
downloadnetworkx-0793d53d86cd45eab8ba188bcc7f11901eb56c17.tar.gz
Allow classes to relabel nodes -- casting (#5903)
-rw-r--r--networkx/relabel.py8
-rw-r--r--networkx/tests/test_relabel.py6
2 files changed, 12 insertions, 2 deletions
diff --git a/networkx/relabel.py b/networkx/relabel.py
index 80d973a2..ec341423 100644
--- a/networkx/relabel.py
+++ b/networkx/relabel.py
@@ -111,9 +111,13 @@ def relabel_nodes(G, mapping, copy=True):
--------
convert_node_labels_to_integers
"""
- # you can pass a function f(old_label)->new_label
+ # you can pass a function f(old_label) -> new_label
+ # or a class e.g. str(old_label) -> new_label
# but we'll just make a dictionary here regardless
- if not hasattr(mapping, "__getitem__"):
+ # To allow classes, we check if __getitem__ is a bound method using __self__
+ if not (
+ hasattr(mapping, "__getitem__") and hasattr(mapping.__getitem__, "__self__")
+ ):
m = {n: mapping(n) for n in G}
else:
m = mapping
diff --git a/networkx/tests/test_relabel.py b/networkx/tests/test_relabel.py
index a5d59f20..c30475b1 100644
--- a/networkx/tests/test_relabel.py
+++ b/networkx/tests/test_relabel.py
@@ -106,6 +106,12 @@ class TestRelabel:
H = nx.relabel_nodes(G, mapping)
assert nodes_equal(H.nodes(), [65, 66, 67, 68])
+ def test_relabel_nodes_classes(self):
+ G = nx.empty_graph()
+ G.add_edges_from([(0, 1), (0, 2), (1, 2), (2, 3)])
+ H = nx.relabel_nodes(G, str)
+ assert nodes_equal(H.nodes, ["0", "1", "2", "3"])
+
def test_relabel_nodes_graph(self):
G = nx.Graph([("A", "B"), ("A", "C"), ("B", "C"), ("C", "D")])
mapping = {"A": "aardvark", "B": "bear", "C": "cat", "D": "dog"}