diff options
author | Berlin Cho <berlinchoose@gmail.com> | 2021-06-15 01:31:59 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-14 20:31:59 +0300 |
commit | b442619bd60651574b733e06af79d8da17c93254 (patch) | |
tree | 2150a4802315c38bef21944f5f244f5123455308 | |
parent | c73d4ab4e30c5d0dea0aa58e6d22ccb20061fbe3 (diff) | |
download | networkx-b442619bd60651574b733e06af79d8da17c93254.tar.gz |
bugfix-for-issue-4353: modify default edge_id format (#4842)
Allow selection of an edge attribute to be the edge id when writing GraphML.
Adds edge_id_from_attribute kwarg.
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | networkx/readwrite/graphml.py | 61 |
| -rw-r--r-- | networkx/readwrite/tests/test_graphml.py | 102 |

2 files changed, 158 insertions, 5 deletions
diff --git a/networkx/readwrite/graphml.py b/networkx/readwrite/graphml.py index 1d3ec6df..5535bbc5 100644 --- a/networkx/readwrite/graphml.py +++ b/networkx/readwrite/graphml.py @@ -61,6 +61,7 @@ def write_graphml_xml( prettyprint=True, infer_numeric_types=False, named_key_ids=False, + edge_id_from_attribute=None, ): """Write G in GraphML XML format to path @@ -81,6 +82,10 @@ def write_graphml_xml( we infer in GraphML that both are floats. named_key_ids : bool (optional) If True use attr.name as value for key elements' id attribute. + edge_id_from_attribute : dict key (optional) + If provided, the graphml edge id is set by looking up the corresponding + edge data attribute keyed by this parameter. If `None` or the key does not exist in edge data, + the edge id is set by the edge key if `G` is a MultiGraph, else the edge id is left unset. Examples -------- @@ -97,6 +102,7 @@ def write_graphml_xml( prettyprint=prettyprint, infer_numeric_types=infer_numeric_types, named_key_ids=named_key_ids, + edge_id_from_attribute=edge_id_from_attribute, ) writer.add_graph_element(G) writer.dump(path) @@ -110,6 +116,7 @@ def write_graphml_lxml( prettyprint=True, infer_numeric_types=False, named_key_ids=False, + edge_id_from_attribute=None, ): """Write G in GraphML XML format to path @@ -133,6 +140,10 @@ def write_graphml_lxml( we infer in GraphML that both are floats. named_key_ids : bool (optional) If True use attr.name as value for key elements' id attribute. + edge_id_from_attribute : dict key (optional) + If provided, the graphml edge id is set by looking up the corresponding + edge data attribute keyed by this parameter. If `None` or the key does not exist in edge data, + the edge id is set by the edge key if `G` is a MultiGraph, else the edge id is left unset. 
Examples -------- @@ -148,7 +159,13 @@ def write_graphml_lxml( import lxml.etree as lxmletree except ImportError: return write_graphml_xml( - G, path, encoding, prettyprint, infer_numeric_types, named_key_ids + G, + path, + encoding, + prettyprint, + infer_numeric_types, + named_key_ids, + edge_id_from_attribute, ) writer = GraphMLWriterLxml( @@ -158,11 +175,18 @@ def write_graphml_lxml( prettyprint=prettyprint, infer_numeric_types=infer_numeric_types, named_key_ids=named_key_ids, + edge_id_from_attribute=edge_id_from_attribute, ) writer.dump() -def generate_graphml(G, encoding="utf-8", prettyprint=True, named_key_ids=False): +def generate_graphml( + G, + encoding="utf-8", + prettyprint=True, + named_key_ids=False, + edge_id_from_attribute=None, +): """Generate GraphML lines for G Parameters @@ -175,6 +199,10 @@ def generate_graphml(G, encoding="utf-8", prettyprint=True, named_key_ids=False) If True use line breaks and indenting in output XML. named_key_ids : bool (optional) If True use attr.name as value for key elements' id attribute. + edge_id_from_attribute : dict key (optional) + If provided, the graphml edge id is set by looking up the corresponding + edge data attribute keyed by this parameter. If `None` or the key does not exist in edge data, + the edge id is set by the edge key if `G` is a MultiGraph, else the edge id is left unset. Examples -------- @@ -190,7 +218,10 @@ def generate_graphml(G, encoding="utf-8", prettyprint=True, named_key_ids=False) edges together) hyperedges, nested graphs, or ports. 
""" writer = GraphMLWriter( - encoding=encoding, prettyprint=prettyprint, named_key_ids=named_key_ids + encoding=encoding, + prettyprint=prettyprint, + named_key_ids=named_key_ids, + edge_id_from_attribute=edge_id_from_attribute, ) writer.add_graph_element(G) yield from str(writer).splitlines() @@ -419,6 +450,7 @@ class GraphMLWriter(GraphML): prettyprint=True, infer_numeric_types=False, named_key_ids=False, + edge_id_from_attribute=None, ): self.construct_types() from xml.etree.ElementTree import Element @@ -428,6 +460,7 @@ class GraphMLWriter(GraphML): self.infer_numeric_types = infer_numeric_types self.prettyprint = prettyprint self.named_key_ids = named_key_ids + self.edge_id_from_attribute = edge_id_from_attribute self.encoding = encoding self.xml = self.myElement( "graphml", @@ -536,14 +569,30 @@ class GraphMLWriter(GraphML): if G.is_multigraph(): for u, v, key, data in G.edges(data=True, keys=True): edge_element = self.myElement( - "edge", source=str(u), target=str(v), id=str(key) + "edge", + source=str(u), + target=str(v), + id=str(data.get(self.edge_id_from_attribute)) + if self.edge_id_from_attribute + and self.edge_id_from_attribute in data + else str(key), ) default = G.graph.get("edge_default", {}) self.add_attributes("edge", edge_element, data, default) graph_element.append(edge_element) else: for u, v, data in G.edges(data=True): - edge_element = self.myElement("edge", source=str(u), target=str(v)) + if self.edge_id_from_attribute and self.edge_id_from_attribute in data: + # select attribute to be edge id + edge_element = self.myElement( + "edge", + source=str(u), + target=str(v), + id=str(data.get(self.edge_id_from_attribute)), + ) + else: + # default: no edge id + edge_element = self.myElement("edge", source=str(u), target=str(v)) default = G.graph.get("edge_default", {}) self.add_attributes("edge", edge_element, data, default) graph_element.append(edge_element) @@ -641,6 +690,7 @@ class GraphMLWriterLxml(GraphMLWriter): prettyprint=True, 
infer_numeric_types=False, named_key_ids=False, + edge_id_from_attribute=None, ): self.construct_types() import lxml.etree as lxmletree @@ -650,6 +700,7 @@ class GraphMLWriterLxml(GraphMLWriter): self._encoding = encoding self._prettyprint = prettyprint self.named_key_ids = named_key_ids + self.edge_id_from_attribute = edge_id_from_attribute self.infer_numeric_types = infer_numeric_types self._xml_base = lxmletree.xmlfile(path, encoding=encoding) diff --git a/networkx/readwrite/tests/test_graphml.py b/networkx/readwrite/tests/test_graphml.py index da3d2974..4e6064a2 100644 --- a/networkx/readwrite/tests/test_graphml.py +++ b/networkx/readwrite/tests/test_graphml.py @@ -1321,6 +1321,108 @@ class TestWriteGraphML(BaseGraphML): os.close(fd) os.unlink(fname) + def test_write_generate_edge_id_from_attribute(self): + from xml.etree.ElementTree import parse + + G = nx.Graph() + G.add_edges_from([("a", "b"), ("b", "c"), ("a", "c")]) + edge_attributes = {e: str(e) for e in G.edges} + nx.set_edge_attributes(G, edge_attributes, "eid") + fd, fname = tempfile.mkstemp() + # set edge_id_from_attribute e.g. "eid" for write_graphml() + self.writer(G, fname, edge_id_from_attribute="eid") + # set edge_id_from_attribute e.g. 
"eid" for generate_graphml() + generator = nx.generate_graphml(G, edge_id_from_attribute="eid") + + H = nx.read_graphml(fname) + assert nodes_equal(G.nodes(), H.nodes()) + assert edges_equal(G.edges(), H.edges()) + # NetworkX adds explicit edge "id" from file as attribute + nx.set_edge_attributes(G, edge_attributes, "id") + assert edges_equal(G.edges(data=True), H.edges(data=True)) + + tree = parse(fname) + children = list(tree.getroot()) + assert len(children) == 2 + edge_ids = [ + edge.attrib["id"] + for edge in tree.getroot().findall( + ".//{http://graphml.graphdrawing.org/xmlns}edge" + ) + ] + # verify edge id value is equal to sepcified attribute value + assert sorted(edge_ids) == sorted(edge_attributes.values()) + + # check graphml generated from generate_graphml() + data = "".join(generator) + J = nx.parse_graphml(data) + assert sorted(G.nodes()) == sorted(J.nodes()) + assert sorted(G.edges()) == sorted(J.edges()) + # NetworkX adds explicit edge "id" from file as attribute + nx.set_edge_attributes(G, edge_attributes, "id") + assert edges_equal(G.edges(data=True), J.edges(data=True)) + + os.close(fd) + os.unlink(fname) + + def test_multigraph_write_generate_edge_id_from_attribute(self): + from xml.etree.ElementTree import parse + + G = nx.MultiGraph() + G.add_edges_from([("a", "b"), ("b", "c"), ("a", "c"), ("a", "b")]) + edge_attributes = {e: str(e) for e in G.edges} + nx.set_edge_attributes(G, edge_attributes, "eid") + fd, fname = tempfile.mkstemp() + # set edge_id_from_attribute e.g. "eid" for write_graphml() + self.writer(G, fname, edge_id_from_attribute="eid") + # set edge_id_from_attribute e.g. 
"eid" for generate_graphml() + generator = nx.generate_graphml(G, edge_id_from_attribute="eid") + + H = nx.read_graphml(fname) + assert H.is_multigraph() + H = nx.read_graphml(fname, force_multigraph=True) + assert H.is_multigraph() + + assert nodes_equal(G.nodes(), H.nodes()) + assert edges_equal(G.edges(), H.edges()) + assert sorted([data.get("eid") for u, v, data in H.edges(data=True)]) == sorted( + edge_attributes.values() + ) + # NetworkX uses edge_ids as keys in multigraphs if no key + assert sorted([key for u, v, key in H.edges(keys=True)]) == sorted( + edge_attributes.values() + ) + + tree = parse(fname) + children = list(tree.getroot()) + assert len(children) == 2 + edge_ids = [ + edge.attrib["id"] + for edge in tree.getroot().findall( + ".//{http://graphml.graphdrawing.org/xmlns}edge" + ) + ] + # verify edge id value is equal to sepcified attribute value + assert sorted(edge_ids) == sorted(edge_attributes.values()) + + # check graphml generated from generate_graphml() + graphml_data = "".join(generator) + J = nx.parse_graphml(graphml_data) + assert J.is_multigraph() + + assert nodes_equal(G.nodes(), J.nodes()) + assert edges_equal(G.edges(), J.edges()) + assert sorted([data.get("eid") for u, v, data in J.edges(data=True)]) == sorted( + edge_attributes.values() + ) + # NetworkX uses edge_ids as keys in multigraphs if no key + assert sorted([key for u, v, key in J.edges(keys=True)]) == sorted( + edge_attributes.values() + ) + + os.close(fd) + os.unlink(fname) + def test_numpy_float64(self): np = pytest.importorskip("numpy") wt = np.float64(3.4) |