summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Li <adam2392@gmail.com>2022-08-23 11:31:58 -0400
committerGitHub <noreply@github.com>2022-08-23 08:31:58 -0700
commitdf9a128f4171d95671e5d9f5460970cc4bf8e3b3 (patch)
treedc06afbc3f62c22efaf958739e517fa79001395d
parent88245f69f89dbee75cef67bdf35bbfb986a42d52 (diff)
downloadnetworkx-df9a128f4171d95671e5d9f5460970cc4bf8e3b3.tar.gz
[ENH] Find and verify a minimal D-separating set in DAG (#5898)
* Ran black * Add unit tests * Rename and fix citation * Black * Fix unite tests * Isort * Add algo description * Update networkx/algorithms/tests/test_d_separation.py * Update networkx/algorithms/traversal/breadth_first_search.py * Address dans comments * Fix unit tests * Update networkx/algorithms/tests/test_d_separation.py Co-authored-by: Dan Schult <dschult@colgate.edu> * Apply suggestions from code review Co-authored-by: Dan Schult <dschult@colgate.edu> * Update networkx/algorithms/dag.py Co-authored-by: Dan Schult <dschult@colgate.edu> * Update networkx/algorithms/dag.py Co-authored-by: Dan Schult <dschult@colgate.edu> * Fix comments * Clean up the docs a bit more * Merge Co-authored-by: Dan Schult <dschult@colgate.edu>
-rw-r--r--doc/release/release_dev.rst3
-rw-r--r--networkx/algorithms/d_separation.py301
-rw-r--r--networkx/algorithms/dag.py33
-rw-r--r--networkx/algorithms/tests/test_d_separation.py50
-rw-r--r--networkx/algorithms/tests/test_dag.py19
5 files changed, 401 insertions, 5 deletions
diff --git a/doc/release/release_dev.rst b/doc/release/release_dev.rst
index 3fba4060..8b2329f4 100644
--- a/doc/release/release_dev.rst
+++ b/doc/release/release_dev.rst
@@ -54,6 +54,9 @@ Improvements
This fixes a bug related to ``mapping=str`` and may change the behavior for
other ``mapping`` arguments that implement both ``__getitem__`` and
``__call__``.
+- [`#5898 <https://github.com/networkx/networkx/pull/5898>`_]
+ Implements computing and checking for minimal d-separators between two nodes.
+ Also adds functionality to DAGs for computing v-structures.
API Changes
-----------
diff --git a/networkx/algorithms/d_separation.py b/networkx/algorithms/d_separation.py
index caf26d0f..ce7fe310 100644
--- a/networkx/algorithms/d_separation.py
+++ b/networkx/algorithms/d_separation.py
@@ -11,6 +11,65 @@ The implementation is based on the conceptually simple linear time
algorithm presented in [2]_. Refer to [3]_, [4]_ for a couple of
alternative algorithms.
+Here, we provide a brief overview of d-separation and related concepts that
+are relevant for understanding it:
+
+Blocking paths
+--------------
+
+Before giving the overview, we introduce the following terminology to describe paths:
+
+- "open" path: A path between two nodes that can be traversed
+- "blocked" path: A path between two nodes that cannot be traversed
+
+A **collider** is a triplet of nodes along a path of the following form:
+``... u -> c <- v ...``, where ``c`` is a common successor of ``u`` and ``v``. A path
+through a collider is considered "blocked" by default. When
+a node that is a collider, or a descendant of a collider, is included in
+the d-separating set, then the path through that collider node is "open". If the
+path through the collider node is open, then we will call this node an open collider.
+
+The d-separation set blocks the paths between ``u`` and ``v``. If you include colliders,
+or their descendant nodes in the d-separation set, then those colliders will open up,
+enabling a path to be traversed if it is not blocked some other way.
+
+Illustration of D-separation with examples
+------------------------------------------
+
+A pair of nodes, ``u`` and ``v``, is considered connected (not d-separated) if
+there is at least one path between ``u`` and ``v`` that is not blocked. That means there is
+an open path between ``u`` and ``v`` that does not encounter a blocked collider, or a
+non-collider variable in the d-separating set.
+
+For example, if the d-separating set is the empty set, then the following paths are
+unblocked between ``u`` and ``v``:
+
+- u <- z -> v
+- u -> w -> ... -> z -> v
+
+If for example, 'z' is in the d-separating set, then 'z' blocks those paths
+between ``u`` and ``v``.
+
+Colliders block a path by default if they and their descendants are not included
+in the d-separating set. An example of a path that is blocked when the d-separating
+set is empty is:
+
+- u -> w -> ... -> z <- v
+
+because 'z' is a collider in this path and 'z' is not in the d-separating set. However,
+if 'z' or a descendant of 'z' is included in the d-separating set, then the path through
+the collider at 'z' (... -> z <- ...) is now "open".
+
+D-separation is concerned with blocking all paths between u and v. Therefore, a
+d-separating set between ``u`` and ``v`` is one where all paths are blocked.
+
+D-separation and its applications in probability
+------------------------------------------------
+
+D-separation is commonly used in probabilistic graphical models. D-separation
+connects the idea of probabilistic "dependence" with separation in a graph. If
+one assumes the causal Markov condition [5]_, then d-separation implies conditional
+independence in probability distributions.
Examples
--------
@@ -55,6 +114,8 @@ References
.. [4] Koller, D., & Friedman, N. (2009).
Probabilistic graphical models: principles and techniques. The MIT Press.
+.. [5] https://en.wikipedia.org/wiki/Causal_Markov_condition
+
"""
from collections import deque
@@ -62,7 +123,7 @@ from collections import deque
import networkx as nx
from networkx.utils import UnionFind, not_implemented_for
-__all__ = ["d_separated"]
+__all__ = ["d_separated", "minimal_d_separator", "is_minimal_d_separator"]
@not_implemented_for("undirected")
@@ -100,6 +161,15 @@ def d_separated(G, x, y, z):
If any of the input nodes are not found in the graph,
a :exc:`NodeNotFound` exception is raised.
+ Notes
+ -----
+ A d-separating set in a DAG is a set of nodes that
+ blocks all paths between the two sets. Nodes in `z`
+ block a path if they are part of the path and are not a collider,
+ or a descendant of a collider. A collider structure along a path
+ is ``... -> c <- ...`` where ``c`` is the collider node.
+
+ https://en.wikipedia.org/wiki/Bayesian_network#d-separation
"""
if not nx.is_directed_acyclic_graph(G):
@@ -140,3 +210,232 @@ def d_separated(G, x, y, z):
return False
else:
return True
+
+
+@not_implemented_for("undirected")
+def minimal_d_separator(G, u, v):
+    """Return one minimal d-separating set for the nodes `u` and `v`.
+
+    A d-separating set in a DAG is a set of nodes that blocks every path
+    between the two nodes `u` and `v`. The set constructed here is
+    "minimal" in the sense of [1]_. Such a set is not necessarily unique.
+    For more details, see Notes.
+
+    Parameters
+    ----------
+    G : graph
+        A networkx DAG.
+    u : node
+        A node in the graph, G.
+    v : node
+        A node in the graph, G.
+
+    Returns
+    -------
+    Z : set
+        The nodes marked by the second breadth-first pass (see Notes).
+
+    Raises
+    ------
+    NetworkXError
+        Raises a :exc:`NetworkXError` if the input graph is not a DAG.
+
+    NodeNotFound
+        If any of the input nodes are not found in the graph,
+        a :exc:`NodeNotFound` exception is raised.
+
+    References
+    ----------
+    .. [1] Tian, J., & Paz, A. (1998). Finding Minimal D-separators.
+
+    Notes
+    -----
+    This function finds only *a* minimal d-separator. There may be more
+    than one minimal d-separator between two nodes of a DAG, so uniqueness
+    is not guaranteed. Only pairs of nodes are handled; finding minimal
+    d-separators between two sets of nodes is not supported.
+
+    Uses the algorithm presented in [1]_. The complexity of the algorithm
+    is :math:`O(|E_{An}^m|)`, where :math:`|E_{An}^m|` stands for the
+    number of edges in the moralized graph of the sub-graph consisting
+    of only the ancestors of `u` and `v`. For full details, see [1]_.
+
+    The steps are:
+
+    1. Build the moral graph of the subgraph induced by `u`, `v` and
+       all of their ancestors.
+    2. Take the predecessors of `u` and `v` as a candidate separating
+       set ``Z'``.
+    3. Run BFS from `u`, marking the nodes of ``Z'`` that are reached;
+       call the marked nodes ``Z''``.
+    4. Run BFS from `v`, marking the nodes of ``Z''`` that are reached.
+       The nodes marked in this pass are returned.
+
+    https://en.wikipedia.org/wiki/Bayesian_network#d-separation
+    """
+    if not nx.is_directed_acyclic_graph(G):
+        raise nx.NetworkXError("graph should be directed acyclic")
+
+    if not all(node in G.nodes for node in {u, v}):
+        raise nx.NodeNotFound("one or more specified nodes not found in the graph")
+
+    # u, v and every ancestor of either node
+    relevant = nx.ancestors(G, u) | nx.ancestors(G, v) | {u, v}
+
+    # moralization of the subgraph induced by those nodes
+    moralized = nx.moral_graph(G.subgraph(relevant))
+
+    # candidate separating set Z': the parents of u together with those of v
+    candidates = set(G.predecessors(u)) | set(G.predecessors(v))
+
+    # Z'': candidates marked by BFS from u; the answer is the part of Z''
+    # that is marked again by BFS from v
+    marked_from_u = _bfs_with_marks(moralized, u, candidates)
+    return _bfs_with_marks(moralized, v, marked_from_u)
+
+
+@not_implemented_for("undirected")
+def is_minimal_d_separator(G, u, v, z):
+    """Determine if a d-separating set is minimal.
+
+    A d-separating set, `z`, in a DAG is a set of nodes that blocks
+    all paths between the two nodes, `u` and `v`. This function
+    checks whether `z` is "minimal" in the sense of [1]_.
+
+    Parameters
+    ----------
+    G : nx.DiGraph
+        The graph.
+    u : node
+        A node in the graph.
+    v : node
+        A node in the graph.
+    z : Set of nodes
+        The set of nodes to check if it is a minimal d-separating set.
+        Any iterable of nodes is accepted; it is copied into a set.
+
+    Returns
+    -------
+    bool
+        Whether or not the `z` separating set is minimal.
+
+    Raises
+    ------
+    NetworkXError
+        Raises a :exc:`NetworkXError` if the input graph is not a DAG.
+
+    NodeNotFound
+        If any of the input nodes are not found in the graph,
+        a :exc:`NodeNotFound` exception is raised.
+
+    References
+    ----------
+    .. [1] Tian, J., & Paz, A. (1998). Finding Minimal D-separators.
+
+    Notes
+    -----
+    This function only verifies minimality between two nodes. Verifying
+    that a d-separating set is minimal between two sets of nodes is not
+    supported.
+
+    This function does not itself check that `z` d-separates `u` and `v`;
+    it only runs the minimality test described below.
+
+    Uses algorithm 2 presented in [1]_. The complexity of the algorithm
+    is :math:`O(|E_{An}^m|)`, where :math:`|E_{An}^m|` stands for the
+    number of edges in the moralized graph of the sub-graph consisting
+    of only the ancestors of ``u`` and ``v``.
+
+    The algorithm works by constructing the moral graph consisting of just
+    `u`, `v` and their ancestors. If `z` contains a node that is not an
+    ancestor of `u` or `v`, it is reported as not minimal right away.
+    Otherwise BFS is run on the moral graph starting from `u`, marking any
+    node of `z` that is encountered and not continuing past a marked node.
+    The same BFS is then repeated starting from `v`. If either pass leaves
+    a node of `z` unmarked, `z` is not minimal; otherwise it is minimal.
+
+    For full details, see [1]_.
+
+    https://en.wikipedia.org/wiki/Bayesian_network#d-separation
+    """
+    if not nx.is_directed_acyclic_graph(G):
+        raise nx.NetworkXError("graph should be directed acyclic")
+
+    # Copy `z` into a set once: it is traversed several times below and is
+    # used for membership tests inside the BFS, so a one-shot iterable would
+    # be exhausted and a list would make every lookup linear.
+    z = set(z)
+
+    if any(n not in G.nodes for n in {u, v} | z):
+        raise nx.NodeNotFound("one or more specified nodes not found in the graph")
+
+    xy_anc = nx.ancestors(G, u) | nx.ancestors(G, v)
+
+    # a node of z that is not an ancestor of u or v makes z non-minimal
+    if not z <= xy_anc:
+        return False
+
+    # moralization of the subgraph induced by u, v and their ancestors
+    moral_G = nx.moral_graph(G.subgraph(xy_anc | {u, v}))
+
+    # every node of z has to be marked by the BFS from u and by the BFS from v
+    for start in (u, v):
+        if not z <= _bfs_with_marks(moral_G, start, z):
+            return False
+
+    return True
+
+
+@not_implemented_for("directed")
+def _bfs_with_marks(G, start_node, check_set):
+    """Breadth-first-search with markings.
+
+    Performs BFS starting from ``start_node`` and whenever a node
+    inside ``check_set`` is met, it is "marked". Once a node is marked,
+    BFS does not continue along that path. The resulting marked nodes
+    are returned.
+
+    ``start_node`` itself is never marked, even if it belongs to
+    ``check_set``: only nodes discovered as neighbors are tested.
+
+    Parameters
+    ----------
+    G : nx.Graph
+        An undirected graph.
+    start_node : node
+        The start of the BFS.
+    check_set : set
+        The set of nodes to check against.
+
+    Returns
+    -------
+    marked : set
+        A set of nodes that were marked.
+    """
+    visited = {start_node}
+    marked = set()
+    # deque gives O(1) pops from the left; list.pop(0) shifts the whole list
+    queue = deque([start_node])
+
+    while queue:
+        m = queue.popleft()
+
+        for nbr in G.neighbors(m):
+            if nbr in visited:
+                continue
+            # remember the node so it is processed at most once
+            visited.add(nbr)
+
+            # mark the node and do not continue along that path
+            if nbr in check_set:
+                marked.add(nbr)
+            else:
+                queue.append(nbr)
+    return marked
diff --git a/networkx/algorithms/dag.py b/networkx/algorithms/dag.py
index d1a33555..826b87ff 100644
--- a/networkx/algorithms/dag.py
+++ b/networkx/algorithms/dag.py
@@ -8,7 +8,7 @@ to the user to check for that.
import heapq
from collections import deque
from functools import partial
-from itertools import chain, product, starmap
+from itertools import chain, combinations, product, starmap
from math import gcd
import networkx as nx
@@ -30,6 +30,7 @@ __all__ = [
"dag_longest_path",
"dag_longest_path_length",
"dag_to_branching",
+ "compute_v_structures",
]
chaini = chain.from_iterable
@@ -1187,3 +1188,33 @@ def dag_to_branching(G):
B.remove_node(0)
B.remove_node(-1)
return B
+
+
+@not_implemented_for("undirected")
+def compute_v_structures(G):
+    """Iterate through the graph to compute all v-structures.
+
+    V-structures are triples in the directed graph where
+    two parent nodes point to the same child and the two parent nodes
+    are not adjacent.
+
+    Parameters
+    ----------
+    G : graph
+        A networkx DiGraph.
+
+    Returns
+    -------
+    vstructs : iterator of tuples
+        The v structures within the graph. Each v structure is a 3-tuple with the
+        parent, collider, and other parent. The two parents appear in sorted
+        order.
+
+    Notes
+    -----
+    https://en.wikipedia.org/wiki/Collider_(statistics)
+    """
+    for collider, preds in G.pred.items():
+        for parent, other in combinations(preds, r=2):
+            # parents joined by an edge (in either direction) form a collider
+            # but not a v-structure, so they are skipped
+            if G.has_edge(parent, other) or G.has_edge(other, parent):
+                continue
+            # sort the parents so each v-structure is reported in one fixed order
+            parent, other = sorted((parent, other))
+            yield (parent, collider, other)
diff --git a/networkx/algorithms/tests/test_d_separation.py b/networkx/algorithms/tests/test_d_separation.py
index 23367a00..74c16ae2 100644
--- a/networkx/algorithms/tests/test_d_separation.py
+++ b/networkx/algorithms/tests/test_d_separation.py
@@ -132,11 +132,16 @@ def test_undirected_graphs_are_not_supported():
"""
Test that undirected graphs are not supported.
- d-separation does not apply in the case of undirected graphs.
+ d-separation and its related algorithms do not apply in
+ the case of undirected graphs.
"""
+ g = nx.path_graph(3, nx.Graph)
with pytest.raises(nx.NetworkXNotImplemented):
- g = nx.path_graph(3, nx.Graph)
nx.d_separated(g, {0}, {1}, {2})
+ with pytest.raises(nx.NetworkXNotImplemented):
+ nx.is_minimal_d_separator(g, {0}, {1}, {2})
+ with pytest.raises(nx.NetworkXNotImplemented):
+ nx.minimal_d_separator(g, {0}, {1})
def test_cyclic_graphs_raise_error():
@@ -145,9 +150,13 @@ def test_cyclic_graphs_raise_error():
This is because PGMs assume a directed acyclic graph.
"""
+ g = nx.cycle_graph(3, nx.DiGraph)
with pytest.raises(nx.NetworkXError):
- g = nx.cycle_graph(3, nx.DiGraph)
nx.d_separated(g, {0}, {1}, {2})
+ with pytest.raises(nx.NetworkXError):
+ nx.minimal_d_separator(g, {0}, {1})
+ with pytest.raises(nx.NetworkXError):
+ nx.is_minimal_d_separator(g, {0}, {1}, {2})
def test_invalid_nodes_raise_error(asia_graph):
@@ -156,3 +165,38 @@ def test_invalid_nodes_raise_error(asia_graph):
"""
with pytest.raises(nx.NodeNotFound):
nx.d_separated(asia_graph, {0}, {1}, {2})
+ with pytest.raises(nx.NodeNotFound):
+ nx.is_minimal_d_separator(asia_graph, 0, 1, {2})
+ with pytest.raises(nx.NodeNotFound):
+ nx.minimal_d_separator(asia_graph, 0, 1)
+
+
+def test_minimal_d_separator():
+    """Check the minimal d-separator functions on two small DAGs."""
+    # Case 1: A -> B <- C, B -> D -> E, B -> F, G -> E
+    dag = nx.DiGraph(
+        [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("B", "F"), ("G", "E")]
+    )
+    # with an empty conditioning set B and E are connected
+    assert not nx.d_separated(dag, {"B"}, {"E"}, set())
+
+    # the separator found for B and E must pass the minimality check
+    # and is expected to be exactly {D}
+    separator = nx.minimal_d_separator(dag, "B", "E")
+    assert nx.is_minimal_d_separator(dag, "B", "E", separator)
+    assert separator == {"D"}
+
+    # Case 2: A -> B -> C, B -> D -> C
+    dag = nx.DiGraph([("A", "B"), ("B", "C"), ("B", "D"), ("D", "C")])
+    assert not nx.d_separated(dag, {"A"}, {"C"}, set())
+
+    # the separator found for A and C must pass the minimality check
+    # and is expected to be exactly {B}
+    separator = nx.minimal_d_separator(dag, "A", "C")
+    assert nx.is_minimal_d_separator(dag, "A", "C", separator)
+    assert separator == {"B"}
diff --git a/networkx/algorithms/tests/test_dag.py b/networkx/algorithms/tests/test_dag.py
index b39b0334..56f16c4f 100644
--- a/networkx/algorithms/tests/test_dag.py
+++ b/networkx/algorithms/tests/test_dag.py
@@ -708,3 +708,22 @@ def test_ancestors_descendants_undirected():
undirected graphs."""
G = nx.path_graph(5)
nx.ancestors(G, 2) == nx.descendants(G, 2) == {0, 1, 3, 4}
+
+
+def test_compute_v_structures_raise():
+    """Calling `compute_v_structures` on an undirected graph must raise."""
+    undirected = nx.Graph()
+    with pytest.raises(nx.NetworkXNotImplemented):
+        nx.compute_v_structures(undirected)
+
+
+def test_compute_v_structures():
+    """Check the triples reported by `compute_v_structures` on two DAGs."""
+    # node 2 has the two parents 0 and 3; node 1 has a single parent
+    dag = nx.DiGraph([(0, 1), (0, 2), (3, 2)])
+    found = set(nx.compute_v_structures(dag))
+    assert len(found) == 1
+    assert (0, 2, 3) in found
+
+    # B has parents A and C, E has parents D and G
+    edge_list = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("G", "E")]
+    dag = nx.DiGraph(edge_list)
+    found = set(nx.compute_v_structures(dag))
+    assert len(found) == 2