summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBerlin Cho <berlinchoose@gmail.com>2021-06-15 01:31:59 +0800
committerGitHub <noreply@github.com>2021-06-14 20:31:59 +0300
commitb442619bd60651574b733e06af79d8da17c93254 (patch)
tree2150a4802315c38bef21944f5f244f5123455308
parentc73d4ab4e30c5d0dea0aa58e6d22ccb20061fbe3 (diff)
downloadnetworkx-b442619bd60651574b733e06af79d8da17c93254.tar.gz
bugfix-for-issue-4353: modify default edge_id format (#4842)
Allow selection of an edge attribute to be the edge id when writing GraphML. Adds edge_id_from_attribute kwarg.
-rw-r--r--networkx/readwrite/graphml.py61
-rw-r--r--networkx/readwrite/tests/test_graphml.py102
2 files changed, 158 insertions, 5 deletions
diff --git a/networkx/readwrite/graphml.py b/networkx/readwrite/graphml.py
index 1d3ec6df..5535bbc5 100644
--- a/networkx/readwrite/graphml.py
+++ b/networkx/readwrite/graphml.py
@@ -61,6 +61,7 @@ def write_graphml_xml(
prettyprint=True,
infer_numeric_types=False,
named_key_ids=False,
+ edge_id_from_attribute=None,
):
"""Write G in GraphML XML format to path
@@ -81,6 +82,10 @@ def write_graphml_xml(
we infer in GraphML that both are floats.
named_key_ids : bool (optional)
If True use attr.name as value for key elements' id attribute.
+ edge_id_from_attribute : dict key (optional)
+ If provided, the graphml edge id is set by looking up the corresponding
+ edge data attribute keyed by this parameter. If `None` or the key does not exist in edge data,
+ the edge id is set by the edge key if `G` is a MultiGraph, else the edge id is left unset.
Examples
--------
@@ -97,6 +102,7 @@ def write_graphml_xml(
prettyprint=prettyprint,
infer_numeric_types=infer_numeric_types,
named_key_ids=named_key_ids,
+ edge_id_from_attribute=edge_id_from_attribute,
)
writer.add_graph_element(G)
writer.dump(path)
@@ -110,6 +116,7 @@ def write_graphml_lxml(
prettyprint=True,
infer_numeric_types=False,
named_key_ids=False,
+ edge_id_from_attribute=None,
):
"""Write G in GraphML XML format to path
@@ -133,6 +140,10 @@ def write_graphml_lxml(
we infer in GraphML that both are floats.
named_key_ids : bool (optional)
If True use attr.name as value for key elements' id attribute.
+ edge_id_from_attribute : dict key (optional)
+ If provided, the graphml edge id is set by looking up the corresponding
+ edge data attribute keyed by this parameter. If `None` or the key does not exist in edge data,
+ the edge id is set by the edge key if `G` is a MultiGraph, else the edge id is left unset.
Examples
--------
@@ -148,7 +159,13 @@ def write_graphml_lxml(
import lxml.etree as lxmletree
except ImportError:
return write_graphml_xml(
- G, path, encoding, prettyprint, infer_numeric_types, named_key_ids
+ G,
+ path,
+ encoding,
+ prettyprint,
+ infer_numeric_types,
+ named_key_ids,
+ edge_id_from_attribute,
)
writer = GraphMLWriterLxml(
@@ -158,11 +175,18 @@ def write_graphml_lxml(
prettyprint=prettyprint,
infer_numeric_types=infer_numeric_types,
named_key_ids=named_key_ids,
+ edge_id_from_attribute=edge_id_from_attribute,
)
writer.dump()
-def generate_graphml(G, encoding="utf-8", prettyprint=True, named_key_ids=False):
+def generate_graphml(
+ G,
+ encoding="utf-8",
+ prettyprint=True,
+ named_key_ids=False,
+ edge_id_from_attribute=None,
+):
"""Generate GraphML lines for G
Parameters
@@ -175,6 +199,10 @@ def generate_graphml(G, encoding="utf-8", prettyprint=True, named_key_ids=False)
If True use line breaks and indenting in output XML.
named_key_ids : bool (optional)
If True use attr.name as value for key elements' id attribute.
+ edge_id_from_attribute : dict key (optional)
+ If provided, the graphml edge id is set by looking up the corresponding
+ edge data attribute keyed by this parameter. If `None` or the key does not exist in edge data,
+ the edge id is set by the edge key if `G` is a MultiGraph, else the edge id is left unset.
Examples
--------
@@ -190,7 +218,10 @@ def generate_graphml(G, encoding="utf-8", prettyprint=True, named_key_ids=False)
edges together) hyperedges, nested graphs, or ports.
"""
writer = GraphMLWriter(
- encoding=encoding, prettyprint=prettyprint, named_key_ids=named_key_ids
+ encoding=encoding,
+ prettyprint=prettyprint,
+ named_key_ids=named_key_ids,
+ edge_id_from_attribute=edge_id_from_attribute,
)
writer.add_graph_element(G)
yield from str(writer).splitlines()
@@ -419,6 +450,7 @@ class GraphMLWriter(GraphML):
prettyprint=True,
infer_numeric_types=False,
named_key_ids=False,
+ edge_id_from_attribute=None,
):
self.construct_types()
from xml.etree.ElementTree import Element
@@ -428,6 +460,7 @@ class GraphMLWriter(GraphML):
self.infer_numeric_types = infer_numeric_types
self.prettyprint = prettyprint
self.named_key_ids = named_key_ids
+ self.edge_id_from_attribute = edge_id_from_attribute
self.encoding = encoding
self.xml = self.myElement(
"graphml",
@@ -536,14 +569,30 @@ class GraphMLWriter(GraphML):
if G.is_multigraph():
for u, v, key, data in G.edges(data=True, keys=True):
edge_element = self.myElement(
- "edge", source=str(u), target=str(v), id=str(key)
+ "edge",
+ source=str(u),
+ target=str(v),
+ id=str(data.get(self.edge_id_from_attribute))
+ if self.edge_id_from_attribute
+ and self.edge_id_from_attribute in data
+ else str(key),
)
default = G.graph.get("edge_default", {})
self.add_attributes("edge", edge_element, data, default)
graph_element.append(edge_element)
else:
for u, v, data in G.edges(data=True):
- edge_element = self.myElement("edge", source=str(u), target=str(v))
+ if self.edge_id_from_attribute and self.edge_id_from_attribute in data:
+ # select attribute to be edge id
+ edge_element = self.myElement(
+ "edge",
+ source=str(u),
+ target=str(v),
+ id=str(data.get(self.edge_id_from_attribute)),
+ )
+ else:
+ # default: no edge id
+ edge_element = self.myElement("edge", source=str(u), target=str(v))
default = G.graph.get("edge_default", {})
self.add_attributes("edge", edge_element, data, default)
graph_element.append(edge_element)
@@ -641,6 +690,7 @@ class GraphMLWriterLxml(GraphMLWriter):
prettyprint=True,
infer_numeric_types=False,
named_key_ids=False,
+ edge_id_from_attribute=None,
):
self.construct_types()
import lxml.etree as lxmletree
@@ -650,6 +700,7 @@ class GraphMLWriterLxml(GraphMLWriter):
self._encoding = encoding
self._prettyprint = prettyprint
self.named_key_ids = named_key_ids
+ self.edge_id_from_attribute = edge_id_from_attribute
self.infer_numeric_types = infer_numeric_types
self._xml_base = lxmletree.xmlfile(path, encoding=encoding)
diff --git a/networkx/readwrite/tests/test_graphml.py b/networkx/readwrite/tests/test_graphml.py
index da3d2974..4e6064a2 100644
--- a/networkx/readwrite/tests/test_graphml.py
+++ b/networkx/readwrite/tests/test_graphml.py
@@ -1321,6 +1321,108 @@ class TestWriteGraphML(BaseGraphML):
os.close(fd)
os.unlink(fname)
+ def test_write_generate_edge_id_from_attribute(self):
+ from xml.etree.ElementTree import parse
+
+ G = nx.Graph()
+ G.add_edges_from([("a", "b"), ("b", "c"), ("a", "c")])
+ edge_attributes = {e: str(e) for e in G.edges}
+ nx.set_edge_attributes(G, edge_attributes, "eid")
+ fd, fname = tempfile.mkstemp()
+ # set edge_id_from_attribute e.g. "eid" for write_graphml()
+ self.writer(G, fname, edge_id_from_attribute="eid")
+ # set edge_id_from_attribute e.g. "eid" for generate_graphml()
+ generator = nx.generate_graphml(G, edge_id_from_attribute="eid")
+
+ H = nx.read_graphml(fname)
+ assert nodes_equal(G.nodes(), H.nodes())
+ assert edges_equal(G.edges(), H.edges())
+ # NetworkX adds explicit edge "id" from file as attribute
+ nx.set_edge_attributes(G, edge_attributes, "id")
+ assert edges_equal(G.edges(data=True), H.edges(data=True))
+
+ tree = parse(fname)
+ children = list(tree.getroot())
+ assert len(children) == 2
+ edge_ids = [
+ edge.attrib["id"]
+ for edge in tree.getroot().findall(
+ ".//{http://graphml.graphdrawing.org/xmlns}edge"
+ )
+ ]
+ # verify edge id value is equal to sepcified attribute value
+ assert sorted(edge_ids) == sorted(edge_attributes.values())
+
+ # check graphml generated from generate_graphml()
+ data = "".join(generator)
+ J = nx.parse_graphml(data)
+ assert sorted(G.nodes()) == sorted(J.nodes())
+ assert sorted(G.edges()) == sorted(J.edges())
+ # NetworkX adds explicit edge "id" from file as attribute
+ nx.set_edge_attributes(G, edge_attributes, "id")
+ assert edges_equal(G.edges(data=True), J.edges(data=True))
+
+ os.close(fd)
+ os.unlink(fname)
+
+ def test_multigraph_write_generate_edge_id_from_attribute(self):
+ from xml.etree.ElementTree import parse
+
+ G = nx.MultiGraph()
+ G.add_edges_from([("a", "b"), ("b", "c"), ("a", "c"), ("a", "b")])
+ edge_attributes = {e: str(e) for e in G.edges}
+ nx.set_edge_attributes(G, edge_attributes, "eid")
+ fd, fname = tempfile.mkstemp()
+ # set edge_id_from_attribute e.g. "eid" for write_graphml()
+ self.writer(G, fname, edge_id_from_attribute="eid")
+ # set edge_id_from_attribute e.g. "eid" for generate_graphml()
+ generator = nx.generate_graphml(G, edge_id_from_attribute="eid")
+
+ H = nx.read_graphml(fname)
+ assert H.is_multigraph()
+ H = nx.read_graphml(fname, force_multigraph=True)
+ assert H.is_multigraph()
+
+ assert nodes_equal(G.nodes(), H.nodes())
+ assert edges_equal(G.edges(), H.edges())
+ assert sorted([data.get("eid") for u, v, data in H.edges(data=True)]) == sorted(
+ edge_attributes.values()
+ )
+ # NetworkX uses edge_ids as keys in multigraphs if no key
+ assert sorted([key for u, v, key in H.edges(keys=True)]) == sorted(
+ edge_attributes.values()
+ )
+
+ tree = parse(fname)
+ children = list(tree.getroot())
+ assert len(children) == 2
+ edge_ids = [
+ edge.attrib["id"]
+ for edge in tree.getroot().findall(
+ ".//{http://graphml.graphdrawing.org/xmlns}edge"
+ )
+ ]
+ # verify edge id value is equal to sepcified attribute value
+ assert sorted(edge_ids) == sorted(edge_attributes.values())
+
+ # check graphml generated from generate_graphml()
+ graphml_data = "".join(generator)
+ J = nx.parse_graphml(graphml_data)
+ assert J.is_multigraph()
+
+ assert nodes_equal(G.nodes(), J.nodes())
+ assert edges_equal(G.edges(), J.edges())
+ assert sorted([data.get("eid") for u, v, data in J.edges(data=True)]) == sorted(
+ edge_attributes.values()
+ )
+ # NetworkX uses edge_ids as keys in multigraphs if no key
+ assert sorted([key for u, v, key in J.edges(keys=True)]) == sorted(
+ edge_attributes.values()
+ )
+
+ os.close(fd)
+ os.unlink(fname)
+
def test_numpy_float64(self):
np = pytest.importorskip("numpy")
wt = np.float64(3.4)