diff options
author | Adam Li <adam2392@gmail.com> | 2022-08-23 11:31:58 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-23 08:31:58 -0700 |
commit | df9a128f4171d95671e5d9f5460970cc4bf8e3b3 (patch) | |
tree | dc06afbc3f62c22efaf958739e517fa79001395d | |
parent | 88245f69f89dbee75cef67bdf35bbfb986a42d52 (diff) | |
download | networkx-df9a128f4171d95671e5d9f5460970cc4bf8e3b3.tar.gz |
[ENH] Find and verify a minimal D-separating set in DAG (#5898)
* Ran black
* Add unit tests
* Rename and fix citation
* Black
* Fix unite tests
* Isort
* Add algo description
* Update networkx/algorithms/tests/test_d_separation.py
* Update networkx/algorithms/traversal/breadth_first_search.py
* Address dans comments
* Fix unit tests
* Update networkx/algorithms/tests/test_d_separation.py
Co-authored-by: Dan Schult <dschult@colgate.edu>
* Apply suggestions from code review
Co-authored-by: Dan Schult <dschult@colgate.edu>
* Update networkx/algorithms/dag.py
Co-authored-by: Dan Schult <dschult@colgate.edu>
* Update networkx/algorithms/dag.py
Co-authored-by: Dan Schult <dschult@colgate.edu>
* Fix comments
* Clean up the docs a bit more
* Merge
Co-authored-by: Dan Schult <dschult@colgate.edu>
-rw-r--r-- | doc/release/release_dev.rst | 3 | ||||
-rw-r--r-- | networkx/algorithms/d_separation.py | 301 | ||||
-rw-r--r-- | networkx/algorithms/dag.py | 33 | ||||
-rw-r--r-- | networkx/algorithms/tests/test_d_separation.py | 50 | ||||
-rw-r--r-- | networkx/algorithms/tests/test_dag.py | 19 |
5 files changed, 401 insertions, 5 deletions
diff --git a/doc/release/release_dev.rst b/doc/release/release_dev.rst index 3fba4060..8b2329f4 100644 --- a/doc/release/release_dev.rst +++ b/doc/release/release_dev.rst @@ -54,6 +54,9 @@ Improvements This fixes a bug related for ``mapping=str`` and may change the behavior for other ``mapping`` arguments that implement both ``__getitem__`` and ``__call__``. +- [`#5898 <https://github.com/networkx/networkx/pull/5898>`_] + Implements computing and checking for minimal d-separators between two nodes. + Also adds functionality to DAGs for computing v-structures. API Changes ----------- diff --git a/networkx/algorithms/d_separation.py b/networkx/algorithms/d_separation.py index caf26d0f..ce7fe310 100644 --- a/networkx/algorithms/d_separation.py +++ b/networkx/algorithms/d_separation.py @@ -11,6 +11,65 @@ The implementation is based on the conceptually simple linear time algorithm presented in [2]_. Refer to [3]_, [4]_ for a couple of alternative algorithms. +Here, we provide a brief overview of d-separation and related concepts that +are relevant for understanding it: + +Blocking paths +-------------- + +Before we overview, we introduce the following terminology to describe paths: + +- "open" path: A path between two nodes that can be traversed +- "blocked" path: A path between two nodes that cannot be traversed + +A **collider** is a triplet of nodes along a path that is like the following: +``... u -> c <- v ...``), where 'c' is a common successor of ``u`` and ``v``. A path +through a collider is considered "blocked". When +a node that is a collider, or a descendant of a collider is included in +the d-separating set, then the path through that collider node is "open". If the +path through the collider node is open, then we will call this node an open collider. + +The d-separation set blocks the paths between ``u`` and ``v``. If you include colliders, +or their descendant nodes in the d-separation set, then those colliders will open up, +enabling a path to be traversed if it is not blocked some other way. + +Illustration of D-separation with examples +------------------------------------------ + +For a pair of two nodes, ``u`` and ``v``, all paths are considered open if +there is a path between ``u`` and ``v`` that is not blocked. That means, there is an open +path between ``u`` and ``v`` that does not encounter a collider, or a variable in the +d-separating set. + +For example, if the d-separating set is the empty set, then the following paths are +unblocked between ``u`` and ``v``: + +- u <- z -> v +- u -> w -> ... -> z -> v + +If for example, 'z' is in the d-separating set, then 'z' blocks those paths +between ``u`` and ``v``. + +Colliders block a path by default if they and their descendants are not included +in the d-separating set. An example of a path that is blocked when the d-separating +set is empty is: + +- u -> w -> ... -> z <- v + +because 'z' is a collider in this path and 'z' is not in the d-separating set. However, +if 'z' or a descendant of 'z' is included in the d-separating set, then the path through +the collider at 'z' (... -> z <- ...) is now "open". + +D-separation is concerned with blocking all paths between u and v. Therefore, a +d-separating set between ``u`` and ``v`` is one where all paths are blocked. + +D-separation and its applications in probability +------------------------------------------------ + +D-separation is commonly used in probabilistic graphical models. D-separation +connects the idea of probabilistic "dependence" with separation in a graph. If +one assumes the causal Markov condition [5]_, then d-separation implies conditional +independence in probability distributions. Examples -------- @@ -55,6 +114,8 @@ References .. [4] Koller, D., & Friedman, N. (2009). Probabilistic graphical models: principles and techniques. The MIT Press. +.. [5] https://en.wikipedia.org/wiki/Causal_Markov_condition + """ from collections import deque @@ -62,7 +123,7 @@ from collections import deque import networkx as nx from networkx.utils import UnionFind, not_implemented_for -__all__ = ["d_separated"] +__all__ = ["d_separated", "minimal_d_separator", "is_minimal_d_separator"] @not_implemented_for("undirected") @@ -100,6 +161,15 @@ def d_separated(G, x, y, z): If any of the input nodes are not found in the graph, a :exc:`NodeNotFound` exception is raised. + Notes + ----- + A d-separating set in a DAG is a set of nodes that + blocks all paths between the two sets. Nodes in `z` + block a path if they are part of the path and are not a collider, + or a descendant of a collider. A collider structure along a path + is ``... -> c <- ...`` where ``c`` is the collider node. + + https://en.wikipedia.org/wiki/Bayesian_network#d-separation """ if not nx.is_directed_acyclic_graph(G): @@ -140,3 +210,232 @@ def d_separated(G, x, y, z): return False else: return True + + +@not_implemented_for("undirected") +def minimal_d_separator(G, u, v): + """Compute a minimal d-separating set between 'u' and 'v'. + + A d-separating set in a DAG is a set of nodes that blocks all paths + between the two nodes, 'u' and 'v'. This function + constructs a d-separating set that is "minimal", meaning it is the smallest + d-separating set for 'u' and 'v'. This is not necessarily + unique. For more details, see Notes. + + Parameters + ---------- + G : graph + A networkx DAG. + u : node + A node in the graph, G. + v : node + A node in the graph, G. + + Raises + ------ + NetworkXError + Raises a :exc:`NetworkXError` if the input graph is not a DAG. + + NodeNotFound + If any of the input nodes are not found in the graph, + a :exc:`NodeNotFound` exception is raised. + + References + ---------- + .. [1] Tian, J., & Paz, A. (1998). Finding Minimal D-separators. + + Notes + ----- + This function only finds ``a`` minimal d-separator. It does not guarantee + uniqueness, since in a DAG there may be more than one minimal d-separator + between two nodes. Moreover, this only checks for minimal separators + between two nodes, not two sets. Finding minimal d-separators between + two sets of nodes is not supported. + + Uses the algorithm presented in [1]_. The complexity of the algorithm + is :math:`O(|E_{An}^m|)`, where :math:`|E_{An}^m|` stands for the + number of edges in the moralized graph of the sub-graph consisting + of only the ancestors of 'u' and 'v'. For full details, see [1]_. + + The algorithm works by constructing the moral graph consisting of just + the ancestors of `u` and `v`. Then it constructs a candidate for + a separating set ``Z'`` from the predecessors of `u` and `v`. + Then BFS is run starting from `u` and marking nodes + found from ``Z'`` and calling those nodes ``Z''``. + Then BFS is run again starting from `v` and marking nodes if they are + present in ``Z''``. Those marked nodes are the returned minimal + d-separating set. + + https://en.wikipedia.org/wiki/Bayesian_network#d-separation + """ + if not nx.is_directed_acyclic_graph(G): + raise nx.NetworkXError("graph should be directed acyclic") + + union_uv = {u, v} + + if any(n not in G.nodes for n in union_uv): + raise nx.NodeNotFound("one or more specified nodes not found in the graph") + + # first construct the set of ancestors of X and Y + x_anc = nx.ancestors(G, u) + y_anc = nx.ancestors(G, v) + D_anc_xy = x_anc.union(y_anc) + D_anc_xy.update((u, v)) + + # second, construct the moralization of the subgraph of Anc(X,Y) + moral_G = nx.moral_graph(G.subgraph(D_anc_xy)) + + # find a separating set Z' in moral_G + Z_prime = set(G.predecessors(u)).union(set(G.predecessors(v))) + + # perform BFS on the graph from 'x' to mark + Z_dprime = _bfs_with_marks(moral_G, u, Z_prime) + Z = _bfs_with_marks(moral_G, v, Z_dprime) + return Z + + +@not_implemented_for("undirected") +def is_minimal_d_separator(G, u, v, z): + """Determine if a d-separating set is minimal. + + A d-separating set, `z`, in a DAG is a set of nodes that blocks + all paths between the two nodes, `u` and `v`. This function + verifies that a set is "minimal", meaning there is no smaller + d-separating set between the two nodes. + + Parameters + ---------- + G : nx.DiGraph + The graph. + u : node + A node in the graph. + v : node + A node in the graph. + z : Set of nodes + The set of nodes to check if it is a minimal d-separating set. + + Returns + ------- + bool + Whether or not the `z` separating set is minimal. + + Raises + ------ + NetworkXError + Raises a :exc:`NetworkXError` if the input graph is not a DAG. + + NodeNotFound + If any of the input nodes are not found in the graph, + a :exc:`NodeNotFound` exception is raised. + + References + ---------- + .. [1] Tian, J., & Paz, A. (1998). Finding Minimal D-separators. + + Notes + ----- + This function only works on verifying a d-separating set is minimal + between two nodes. To verify that a d-separating set is minimal between + two sets of nodes is not supported. + + Uses algorithm 2 presented in [1]_. The complexity of the algorithm + is :math:`O(|E_{An}^m|)`, where :math:`|E_{An}^m|` stands for the + number of edges in the moralized graph of the sub-graph consisting + of only the ancestors of ``u`` and ``v``. + + The algorithm works by constructing the moral graph consisting of just + the ancestors of `u` and `v`. First, it performs BFS on the moral graph + starting from `u` and marking any nodes it encounters that are part of + the separating set, `z`. If a node is marked, then it does not continue + along that path. In the second stage, BFS with markings is repeated on the + moral graph starting from `v`. If at any stage, any node in `z` is + not marked, then `z` is considered not minimal. If the end of the algorithm + is reached, then `z` is minimal. + + For full details, see [1]_. + + https://en.wikipedia.org/wiki/Bayesian_network#d-separation + """ + if not nx.is_directed_acyclic_graph(G): + raise nx.NetworkXError("graph should be directed acyclic") + + union_uv = {u, v} + union_uv.update(z) + + if any(n not in G.nodes for n in union_uv): + raise nx.NodeNotFound("one or more specified nodes not found in the graph") + + x_anc = nx.ancestors(G, u) + y_anc = nx.ancestors(G, v) + xy_anc = x_anc.union(y_anc) + + # if Z contains any node which is not in ancestors of X or Y + # then it is definitely not minimal + if any(node not in xy_anc for node in z): + return False + + D_anc_xy = x_anc.union(y_anc) + D_anc_xy.update((u, v)) + + # second, construct the moralization of the subgraph + moral_G = nx.moral_graph(G.subgraph(D_anc_xy)) + + # start BFS from X + marks = _bfs_with_marks(moral_G, u, z) + + # if not all the Z is marked, then the set is not minimal + if any(node not in marks for node in z): + return False + + # similarly, start BFS from Y and check the marks + marks = _bfs_with_marks(moral_G, v, z) + # if not all the Z is marked, then the set is not minimal + if any(node not in marks for node in z): + return False + + return True + + +@not_implemented_for("directed") +def _bfs_with_marks(G, start_node, check_set): + """Breadth-first-search with markings. + + Performs BFS starting from ``start_node`` and whenever a node + inside ``check_set`` is met, it is "marked". Once a node is marked, + BFS does not continue along that path. The resulting marked nodes + are returned. + + Parameters + ---------- + G : nx.Graph + An undirected graph. + start_node : node + The start of the BFS. + check_set : set + The set of nodes to check against. + + Returns + ------- + marked : set + A set of nodes that were marked. + """ + visited = dict() + marked = set() + queue = [] + + visited[start_node] = None + queue.append(start_node) + while queue: + m = queue.pop(0) + + for nbr in G.neighbors(m): + if nbr not in visited: + # memoize where we visited so far + visited[nbr] = None + + # mark the node in Z' and do not continue along that path + if nbr in check_set: + marked.add(nbr) + else: + queue.append(nbr) + return marked diff --git a/networkx/algorithms/dag.py b/networkx/algorithms/dag.py index d1a33555..826b87ff 100644 --- a/networkx/algorithms/dag.py +++ b/networkx/algorithms/dag.py @@ -8,7 +8,7 @@ to the user to check for that. import heapq from collections import deque from functools import partial -from itertools import chain, product, starmap +from itertools import chain, combinations, product, starmap from math import gcd import networkx as nx @@ -30,6 +30,7 @@ __all__ = [ "dag_longest_path", "dag_longest_path_length", "dag_to_branching", + "compute_v_structures", ] chaini = chain.from_iterable @@ -1187,3 +1188,33 @@ def dag_to_branching(G): B.remove_node(0) B.remove_node(-1) return B + + +@not_implemented_for("undirected") +def compute_v_structures(G): + """Iterate through the graph to compute all v-structures. + + V-structures are triples in the directed graph where + two parent nodes point to the same child and the two parent nodes + are not adjacent. + + Parameters + ---------- + G : graph + A networkx DiGraph. + + Returns + ------- + vstructs : iterator of tuples + The v structures within the graph. Each v structure is a 3-tuple with the + parent, collider, and other parent. + + Notes + ----- + https://en.wikipedia.org/wiki/Collider_(statistics) + """ + for collider, preds in G.pred.items(): + for common_parents in combinations(preds, r=2): + # ensure that the colliders are the same + common_parents = sorted(common_parents) + yield (common_parents[0], collider, common_parents[1]) diff --git a/networkx/algorithms/tests/test_d_separation.py b/networkx/algorithms/tests/test_d_separation.py index 23367a00..74c16ae2 100644 --- a/networkx/algorithms/tests/test_d_separation.py +++ b/networkx/algorithms/tests/test_d_separation.py @@ -132,11 +132,16 @@ def test_undirected_graphs_are_not_supported(): """ Test that undirected graphs are not supported. - d-separation does not apply in the case of undirected graphs. + d-separation and its related algorithms do not apply in + the case of undirected graphs. """ + g = nx.path_graph(3, nx.Graph) with pytest.raises(nx.NetworkXNotImplemented): - g = nx.path_graph(3, nx.Graph) nx.d_separated(g, {0}, {1}, {2}) + with pytest.raises(nx.NetworkXNotImplemented): + nx.is_minimal_d_separator(g, {0}, {1}, {2}) + with pytest.raises(nx.NetworkXNotImplemented): + nx.minimal_d_separator(g, {0}, {1}) def test_cyclic_graphs_raise_error(): @@ -145,9 +150,13 @@ def test_cyclic_graphs_raise_error(): This is because PGMs assume a directed acyclic graph. """ + g = nx.cycle_graph(3, nx.DiGraph) with pytest.raises(nx.NetworkXError): - g = nx.cycle_graph(3, nx.DiGraph) nx.d_separated(g, {0}, {1}, {2}) + with pytest.raises(nx.NetworkXError): + nx.minimal_d_separator(g, {0}, {1}) + with pytest.raises(nx.NetworkXError): + nx.is_minimal_d_separator(g, {0}, {1}, {2}) def test_invalid_nodes_raise_error(asia_graph): @@ -156,3 +165,38 @@ def test_invalid_nodes_raise_error(asia_graph): """ with pytest.raises(nx.NodeNotFound): nx.d_separated(asia_graph, {0}, {1}, {2}) + with pytest.raises(nx.NodeNotFound): + nx.is_minimal_d_separator(asia_graph, 0, 1, {2}) + with pytest.raises(nx.NodeNotFound): + nx.minimal_d_separator(asia_graph, 0, 1) + + +def test_minimal_d_separator(): + # Case 1: + # create a graph A -> B <- C + # B -> D -> E; + # B -> F; + # G -> E; + edge_list = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("B", "F"), ("G", "E")] + G = nx.DiGraph(edge_list) + assert not nx.d_separated(G, {"B"}, {"E"}, set()) + + # minimal set of the corresponding graph + # for B and E should be (D,) + Zmin = nx.minimal_d_separator(G, "B", "E") + + # the minimal separating set should pass the test for minimality + assert nx.is_minimal_d_separator(G, "B", "E", Zmin) + assert Zmin == {"D"} + + # Case 2: + # create a graph A -> B -> C + # B -> D -> C; + edge_list = [("A", "B"), ("B", "C"), ("B", "D"), ("D", "C")] + G = nx.DiGraph(edge_list) + assert not nx.d_separated(G, {"A"}, {"C"}, set()) + Zmin = nx.minimal_d_separator(G, "A", "C") + + # the minimal separating set should pass the test for minimality + assert nx.is_minimal_d_separator(G, "A", "C", Zmin) + assert Zmin == {"B"} diff --git a/networkx/algorithms/tests/test_dag.py b/networkx/algorithms/tests/test_dag.py index b39b0334..56f16c4f 100644 --- a/networkx/algorithms/tests/test_dag.py +++ b/networkx/algorithms/tests/test_dag.py @@ -708,3 +708,22 @@ def test_ancestors_descendants_undirected(): undirected graphs.""" G = nx.path_graph(5) nx.ancestors(G, 2) == nx.descendants(G, 2) == {0, 1, 3, 4} + + +def test_compute_v_structures_raise(): + G = nx.Graph() + pytest.raises(nx.NetworkXNotImplemented, nx.compute_v_structures, G) + + +def test_compute_v_structures(): + edges = [(0, 1), (0, 2), (3, 2)] + G = nx.DiGraph(edges) + + v_structs = set(nx.compute_v_structures(G)) + assert len(v_structs) == 1 + assert (0, 2, 3) in v_structs + + edges = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("G", "E")] + G = nx.DiGraph(edges) + v_structs = set(nx.compute_v_structures(G)) + assert len(v_structs) == 2 |