diff options
author | Ross Barnowski <rossbar@berkeley.edu> | 2022-08-14 02:59:11 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-13 16:59:11 -0700 |
commit | 19c1454d3dfa70a893ea67f2d78515658e8c08e5 (patch) | |
tree | 602a90d221072f0977f993f5a1dcec9a8a781ea2 | |
parent | e6d5909b88619d12f2a112db468e981fa68811d6 (diff) | |
download | networkx-19c1454d3dfa70a893ea67f2d78515658e8c08e5.tar.gz |
Replace LCA with naive implementations (#5883)
* WIP: Replace functions to evaluate tests.
* Raise prompt exceptions by wrapping generator.
* Fix erroneous ground-truth self-ancestor in tests.
* Move pair creation outside of generator and validate.
* Convert input with fromkeys to preserve order and rm duplicates.
* Replace LCA implementations & update tests.
* Test cleanup: move new tests into old class.
Allows us to get rid of duplication/another test setup.
* Rm naive fns from refguide.
* Add release note.
* Remove unused imports.
* Remove missed duplicate function (bad rebase).
Co-authored-by: Dilara Tekinoglu <dilaranurtuncturk@gmail.com>
-rw-r--r-- | doc/reference/algorithms/lowest_common_ancestors.rst | 2 | ||||
-rw-r--r-- | doc/release/release_dev.rst | 4 | ||||
-rw-r--r-- | networkx/algorithms/lowest_common_ancestors.py | 369 | ||||
-rw-r--r-- | networkx/algorithms/tests/test_lowest_common_ancestors.py | 88 |
4 files changed, 54 insertions, 409 deletions
diff --git a/doc/reference/algorithms/lowest_common_ancestors.rst b/doc/reference/algorithms/lowest_common_ancestors.rst index 855ecc3a..a82b3571 100644 --- a/doc/reference/algorithms/lowest_common_ancestors.rst +++ b/doc/reference/algorithms/lowest_common_ancestors.rst @@ -9,5 +9,3 @@ Lowest Common Ancestor all_pairs_lowest_common_ancestor tree_all_pairs_lowest_common_ancestor lowest_common_ancestor - naive_all_pairs_lowest_common_ancestor - naive_lowest_common_ancestor diff --git a/doc/release/release_dev.rst b/doc/release/release_dev.rst index 9bec2d31..22541a7d 100644 --- a/doc/release/release_dev.rst +++ b/doc/release/release_dev.rst @@ -43,6 +43,10 @@ Improvements ------------ - [`#5663 <https://github.com/networkx/networkx/pull/5663>`_] Implements edge swapping for directed graphs. +- [`#5883 <https://github.com/networkx/networkx/pull/5883>`_] + Replace the implementation of ``lowest_common_ancestor`` and + ``all_pairs_lowest_common_ancestor`` with a "naive" algorithm to fix + several bugs and improve performance. 
API Changes ----------- diff --git a/networkx/algorithms/lowest_common_ancestors.py b/networkx/algorithms/lowest_common_ancestors.py index ec515a90..68aaf7d3 100644 --- a/networkx/algorithms/lowest_common_ancestors.py +++ b/networkx/algorithms/lowest_common_ancestors.py @@ -1,7 +1,7 @@ """Algorithms for finding the lowest common ancestor of trees and DAGs.""" from collections import defaultdict from collections.abc import Mapping, Set -from itertools import chain, combinations_with_replacement, count +from itertools import combinations_with_replacement import networkx as nx from networkx.utils import UnionFind, arbitrary_element, not_implemented_for @@ -10,14 +10,12 @@ __all__ = [ "all_pairs_lowest_common_ancestor", "tree_all_pairs_lowest_common_ancestor", "lowest_common_ancestor", - "naive_lowest_common_ancestor", - "naive_all_pairs_lowest_common_ancestor", ] @not_implemented_for("undirected") @not_implemented_for("multigraph") -def naive_all_pairs_lowest_common_ancestor(G, pairs=None): +def all_pairs_lowest_common_ancestor(G, pairs=None): """Return the lowest common ancestor of all pairs or the provided pairs Parameters @@ -48,13 +46,13 @@ def naive_all_pairs_lowest_common_ancestor(G, pairs=None): possible combinations of nodes in `G`, including self-pairings: >>> G = nx.DiGraph([(0, 1), (0, 3), (1, 2)]) - >>> dict(nx.naive_all_pairs_lowest_common_ancestor(G)) + >>> dict(nx.all_pairs_lowest_common_ancestor(G)) {(0, 0): 0, (0, 1): 0, (0, 3): 0, (0, 2): 0, (1, 1): 1, (1, 3): 0, (1, 2): 1, (3, 3): 3, (3, 2): 0, (2, 2): 2} The pairs argument can be used to limit the output to only the specified node pairings: - >>> dict(nx.naive_all_pairs_lowest_common_ancestor(G, pairs=[(1, 2), (2, 3)])) + >>> dict(nx.all_pairs_lowest_common_ancestor(G, pairs=[(1, 2), (2, 3)])) {(1, 2): 1, (2, 3): 0} Notes @@ -63,46 +61,59 @@ def naive_all_pairs_lowest_common_ancestor(G, pairs=None): See Also -------- - naive_lowest_common_ancestor + lowest_common_ancestor """ if not 
nx.is_directed_acyclic_graph(G): raise nx.NetworkXError("LCA only defined on directed acyclic graphs.") if len(G) == 0: raise nx.NetworkXPointlessConcept("LCA meaningless on null graphs.") - ancestor_cache = {} - if pairs is None: - pairs = combinations_with_replacement(G, 2) - - for v, w in pairs: - if v not in ancestor_cache: - ancestor_cache[v] = nx.ancestors(G, v) - ancestor_cache[v].add(v) - if w not in ancestor_cache: - ancestor_cache[w] = nx.ancestors(G, w) - ancestor_cache[w].add(w) - - common_ancestors = ancestor_cache[v] & ancestor_cache[w] - - if common_ancestors: - common_ancestor = next(iter(common_ancestors)) - while True: - successor = None - for lower_ancestor in G.successors(common_ancestor): - if lower_ancestor in common_ancestors: - successor = lower_ancestor + else: + # Convert iterator to iterable, if necessary. Trim duplicates. + pairs = dict.fromkeys(pairs) + # Verify that each of the nodes in the provided pairs is in G + nodeset = set(G) + for pair in pairs: + if set(pair) - nodeset: + raise nx.NodeNotFound( + f"Node(s) {set(pair) - nodeset} from pair {pair} not in G." 
+ ) + + # Once input validation is done, construct the generator + def generate_lca_from_pairs(G, pairs): + ancestor_cache = {} + + for v, w in pairs: + if v not in ancestor_cache: + ancestor_cache[v] = nx.ancestors(G, v) + ancestor_cache[v].add(v) + if w not in ancestor_cache: + ancestor_cache[w] = nx.ancestors(G, w) + ancestor_cache[w].add(w) + + common_ancestors = ancestor_cache[v] & ancestor_cache[w] + + if common_ancestors: + common_ancestor = next(iter(common_ancestors)) + while True: + successor = None + for lower_ancestor in G.successors(common_ancestor): + if lower_ancestor in common_ancestors: + successor = lower_ancestor + break + if successor is None: break - if successor is None: - break - common_ancestor = successor - yield ((v, w), common_ancestor) + common_ancestor = successor + yield ((v, w), common_ancestor) + + return generate_lca_from_pairs(G, pairs) @not_implemented_for("undirected") @not_implemented_for("multigraph") -def naive_lowest_common_ancestor(G, node1, node2, default=None): +def lowest_common_ancestor(G, node1, node2, default=None): """Compute the lowest common ancestor of the given pair of nodes. 
Parameters @@ -124,14 +135,14 @@ def naive_lowest_common_ancestor(G, node1, node2, default=None): >>> G = nx.DiGraph() >>> nx.add_path(G, (0, 1, 2, 3)) >>> nx.add_path(G, (0, 4, 3)) - >>> nx.naive_lowest_common_ancestor(G, 2, 4) + >>> nx.lowest_common_ancestor(G, 2, 4) 0 See Also -------- - naive_all_pairs_lowest_common_ancestor""" + all_pairs_lowest_common_ancestor""" - ans = list(naive_all_pairs_lowest_common_ancestor(G, pairs=[(node1, node2)])) + ans = list(all_pairs_lowest_common_ancestor(G, pairs=[(node1, node2)])) if ans: assert len(ans) == 1 return ans[0][1] @@ -254,289 +265,3 @@ def tree_all_pairs_lowest_common_ancestor(G, root=None, pairs=None): parent = arbitrary_element(G.pred[node]) uf.union(parent, node) ancestors[uf[parent]] = parent - - -@not_implemented_for("undirected") -@not_implemented_for("multigraph") -def lowest_common_ancestor(G, node1, node2, default=None): - """Compute the lowest common ancestor of the given pair of nodes. - - Parameters - ---------- - G : NetworkX directed graph - - node1, node2 : nodes in the graph. - - default : object - Returned if no common ancestor between `node1` and `node2` - - Returns - ------- - The lowest common ancestor of node1 and node2, - or default if they have no common ancestors. - - Examples - -------- - >>> G = nx.DiGraph([(0, 1), (0, 2), (2, 3), (2, 4), (1, 6), (4, 5)]) - >>> nx.lowest_common_ancestor(G, 3, 5) - 2 - - We can also set `default` argument as below. The value of default is returned - if there are no common ancestors of given two nodes. - - >>> G = nx.DiGraph([(4, 5), (12, 13)]) - >>> nx.lowest_common_ancestor(G, 12, 5, default="No common ancestors!") - 'No common ancestors!' - - Notes - ----- - Only defined on non-null directed acyclic graphs. - Takes n log(n) time in the size of the graph. - See `all_pairs_lowest_common_ancestor` when you have - more than one pair of nodes of interest. 
- - See Also - -------- - tree_all_pairs_lowest_common_ancestor - all_pairs_lowest_common_ancestor - """ - ans = list(all_pairs_lowest_common_ancestor(G, pairs=[(node1, node2)])) - if ans: - assert len(ans) == 1 - return ans[0][1] - return default - - -@not_implemented_for("undirected") -@not_implemented_for("multigraph") -def all_pairs_lowest_common_ancestor(G, pairs=None): - """Compute the lowest common ancestor for pairs of nodes. - - Parameters - ---------- - G : NetworkX directed graph - - pairs : iterable of pairs of nodes, optional (default: all pairs) - The pairs of nodes of interest. - If None, will find the LCA of all pairs of nodes. - - Yields - ------ - ((node1, node2), lca) : 2-tuple - Where lca is least common ancestor of node1 and node2. - Note that for the default case, the order of the node pair is not considered, - e.g. you will not get both ``(a, b)`` and ``(b, a)`` - - Raises - ------ - NetworkXPointlessConcept - If `G` is null. - NetworkXError - If `G` is not a DAG. - - Examples - -------- - The default behavior is to yield the lowest common ancestor for all - possible combinations of nodes in `G`, including self-pairings: - - >>> G = nx.DiGraph([(0, 1), (0, 3), (1, 2)]) - >>> dict(nx.all_pairs_lowest_common_ancestor(G)) - {(2, 2): 2, (1, 1): 1, (2, 1): 1, (1, 3): 0, (2, 3): 0, (3, 3): 3, (0, 0): 0, (1, 0): 0, (2, 0): 0, (3, 0): 0} - - The `pairs` argument can be used to limit the output to only the - specified node pairings: - - >>> dict(nx.all_pairs_lowest_common_ancestor(G, pairs=[(1, 2), (2, 3)])) - {(2, 3): 0, (1, 2): 1} - - Notes - ----- - Only defined on non-null directed acyclic graphs. - - Uses the $O(n^3)$ ancestor-list algorithm from: - M. A. Bender, M. Farach-Colton, G. Pemmasani, S. Skiena, P. Sumazin. - "Lowest common ancestors in trees and directed acyclic graphs." - Journal of Algorithms, 57(2): 75-94, 2005. 
- - See Also - -------- - tree_all_pairs_lowest_common_ancestor - lowest_common_ancestor - """ - if not nx.is_directed_acyclic_graph(G): - raise nx.NetworkXError("LCA only defined on directed acyclic graphs.") - if len(G) == 0: - raise nx.NetworkXPointlessConcept("LCA meaningless on null graphs.") - - # The copy isn't ideal, neither is the switch-on-type, but without it users - # passing an iterable will encounter confusing errors, and itertools.tee - # does not appear to handle builtin types efficiently (IE, it materializes - # another buffer rather than just creating listoperators at the same - # offset). The Python documentation notes use of tee is unadvised when one - # is consumed before the other. - # - # This will always produce correct results and avoid unnecessary - # copies in many common cases. - # - if not isinstance(pairs, (Mapping, Set)) and pairs is not None: - pairs = set(pairs) - - # Convert G into a dag with a single root by adding a node with edges to - # all sources iff necessary. - sources = [n for n, deg in G.in_degree if deg == 0] - if len(sources) == 1: - root = sources[0] - super_root = None - else: - G = G.copy() - # find unused node - root = -1 - while root in G: - root -= 1 - # use that as the super_root below all sources - super_root = root - for source in sources: - G.add_edge(root, source) - - # Start by computing a spanning tree, and the DAG of all edges not in it. - # We will then use the tree lca algorithm on the spanning tree, and use - # the DAG to figure out the set of tree queries necessary. - spanning_tree = nx.dfs_tree(G, root) - dag = nx.DiGraph( - (u, v) - for u, v in G.edges - if u not in spanning_tree or v not in spanning_tree[u] - ) - - # Ensure that both the dag and the spanning tree contains all nodes in G, - # even nodes that are disconnected in the dag. - spanning_tree.add_nodes_from(G) - dag.add_nodes_from(G) - - counter = count() - - # Necessary to handle graphs consisting of a single node and no edges. 
- root_distance = {root: next(counter)} - - for edge in nx.bfs_edges(spanning_tree, root): - for node in edge: - if node not in root_distance: - root_distance[node] = next(counter) - - # Index the position of all nodes in the Euler tour so we can efficiently - # sort lists and merge in tour order. - euler_tour_pos = {} - for node in nx.depth_first_search.dfs_preorder_nodes(G, root): - if node not in euler_tour_pos: - euler_tour_pos[node] = next(counter) - - # Generate the set of all nodes of interest in the pairs. - pairset = set() - if pairs is not None: - pairset = set(chain.from_iterable(pairs)) - - for n in pairset: - if n not in G: - msg = f"The node {str(n)} is not in the digraph." - raise nx.NodeNotFound(msg) - - # Generate the transitive closure over the dag (not G) of all nodes, and - # sort each node's closure set by order of first appearance in the Euler - # tour. - ancestors = {} - for v in dag: - if pairs is None or v in pairset: - my_ancestors = nx.ancestors(G, v) - my_ancestors.add(v) - ancestors[v] = sorted(my_ancestors, key=euler_tour_pos.get) - - def _compute_dag_lca_from_tree_values(tree_lca, dry_run): - """Iterate through the in-order merge for each pair of interest. - - We do this to answer the user's query, but it is also used to - avoid generating unnecessary tree entries when the user only - needs some pairs. - """ - for (node1, node2) in pairs if pairs is not None else tree_lca: - best_root_distance = None - best = None - - indices = [0, 0] - ancestors_by_index = [ancestors[node1], ancestors[node2]] - - def get_next_in_merged_lists(indices): - """Returns index of the list containing the next item - - Next order refers to the merged order. - Index can be 0 or 1 (or None if exhausted). 
- """ - index1, index2 = indices - if index1 >= len(ancestors[node1]) and index2 >= len(ancestors[node2]): - return None - elif index1 >= len(ancestors[node1]): - return 1 - elif index2 >= len(ancestors[node2]): - return 0 - elif ( - euler_tour_pos[ancestors[node1][index1]] - < euler_tour_pos[ancestors[node2][index2]] - ): - return 0 - else: - return 1 - - # Find the LCA by iterating through the in-order merge of the two - # nodes of interests' ancestor sets. In principle, we need to - # consider all pairs in the Cartesian product of the ancestor sets, - # but by the restricted min range query reduction we are guaranteed - # that one of the pairs of interest is adjacent in the merged list - # iff one came from each list. - i = get_next_in_merged_lists(indices) - cur = ancestors_by_index[i][indices[i]], i - while i is not None: - prev = cur - indices[i] += 1 - i = get_next_in_merged_lists(indices) - if i is not None: - cur = ancestors_by_index[i][indices[i]], i - - # Two adjacent entries must not be from the same list - # in order for their tree LCA to be considered. - if cur[1] != prev[1]: - tree_node1, tree_node2 = prev[0], cur[0] - if (tree_node1, tree_node2) in tree_lca: - ans = tree_lca[tree_node1, tree_node2] - else: - ans = tree_lca[tree_node2, tree_node1] - if not dry_run and ( - best is None or root_distance[ans] > best_root_distance - ): - best_root_distance = root_distance[ans] - best = ans - - # If the LCA is super_root, there is no LCA in the user's graph. - if not dry_run and (super_root is None or best != super_root): - yield (node1, node2), best - - # Generate the spanning tree lca for all pairs. This doesn't make sense to - # do incrementally since we are using a linear time offline algorithm for - # tree lca. - if pairs is None: - # We want all pairs so we'll need the entire tree. 
- tree_lca = dict(tree_all_pairs_lowest_common_ancestor(spanning_tree, root)) - else: - # We only need the merged adjacent pairs by seeing which queries the - # algorithm needs then generating them in a single pass. - tree_lca = defaultdict(int) - for _ in _compute_dag_lca_from_tree_values(tree_lca, True): - pass - - # Replace the bogus default tree values with the real ones. - for (pair, lca) in tree_all_pairs_lowest_common_ancestor( - spanning_tree, root, tree_lca - ): - tree_lca[pair] = lca - - # All precomputations complete. Now we just need to give the user the pairs - # they asked for, or all pairs if they want them all. - return _compute_dag_lca_from_tree_values(tree_lca, False) diff --git a/networkx/algorithms/tests/test_lowest_common_ancestors.py b/networkx/algorithms/tests/test_lowest_common_ancestors.py index be3b9fe5..02b63371 100644 --- a/networkx/algorithms/tests/test_lowest_common_ancestors.py +++ b/networkx/algorithms/tests/test_lowest_common_ancestors.py @@ -6,8 +6,6 @@ import networkx as nx tree_all_pairs_lca = nx.tree_all_pairs_lowest_common_ancestor all_pairs_lca = nx.all_pairs_lowest_common_ancestor -naive_lca = nx.naive_lowest_common_ancestor -naive_all_pairs_lca = nx.naive_all_pairs_lowest_common_ancestor def get_pair(dictionary, n1, n2): @@ -193,7 +191,7 @@ class TestDAGLCA: (2, 6): 6, (2, 7): 7, (2, 8): 7, - (3, 3): 8, + (3, 3): 3, (3, 4): 4, (3, 5): 5, (3, 6): 6, @@ -314,86 +312,6 @@ class TestDAGLCA: G.add_node(3) assert nx.lowest_common_ancestor(G, 3, 3) == 3 - -class TestNaiveLCA: - @classmethod - def setup_class(cls): - cls.DG = nx.DiGraph() - cls.DG.add_nodes_from(range(5)) - cls.DG.add_edges_from([(1, 0), (2, 0), (3, 2), (4, 1), (4, 3)]) - - cls.root_distance = nx.shortest_path_length(cls.DG, source=4) - - cls.gold = { - (0, 0): 0, - (0, 1): 1, - (0, 2): 2, - (0, 3): 3, - (0, 4): 4, - (1, 1): 1, - (1, 2): 4, - (1, 3): 4, - (1, 4): 4, - (2, 2): 2, - (2, 3): 3, - (2, 4): 4, - (3, 3): 3, - (3, 4): 4, - (4, 4): 4, - } - - def 
assert_lca_dicts_same(self, d1, d2, G=None): - """Checks if d1 and d2 contain the same pairs and - have a node at the same distance from root for each. - If G is None use self.DG.""" - if G is None: - G = self.DG - root_distance = self.root_distance - else: - roots = [n for n, deg in G.in_degree if deg == 0] - assert len(roots) == 1 - root_distance = nx.shortest_path_length(G, source=roots[0]) - - for a, b in ((min(pair), max(pair)) for pair in chain(d1, d2)): - assert ( - root_distance[get_pair(d1, a, b)] == root_distance[get_pair(d2, a, b)] - ) - - def test_naive_all_pairs_lowest_common_ancestor1(self): - """Produces the correct results.""" - self.assert_lca_dicts_same(dict(naive_all_pairs_lca(self.DG)), self.gold) - - def test_naive_all_pairs_lowest_common_ancestor2(self): - """Produces the correct results when all pairs given.""" - all_pairs = list(product(self.DG.nodes(), self.DG.nodes())) - ans = naive_all_pairs_lca(self.DG, pairs=all_pairs) - self.assert_lca_dicts_same(dict(ans), self.gold) - - def test_naive_all_pairs_lowest_common_ancestor3(self): - """Produces the correct results when all pairs given as a generator.""" - all_pairs = product(self.DG.nodes(), self.DG.nodes()) - ans = naive_all_pairs_lca(self.DG, pairs=all_pairs) - self.assert_lca_dicts_same(dict(ans), self.gold) - - def test_naive_all_pairs_lowest_common_ancestor4(self): - """Test that LCA on null graph bails.""" - with pytest.raises(nx.NetworkXPointlessConcept): - gen = naive_all_pairs_lca(nx.DiGraph()) - next(gen) - - def test_naive_all_pairs_lowest_common_ancestor5(self): - """Test that LCA on non-dags bails.""" - with pytest.raises(nx.NetworkXError): - gen = naive_all_pairs_lca(nx.DiGraph([(3, 4), (4, 3)])) - next(gen) - - def test_naive_all_pairs_lowest_common_ancestor6(self): - """Test that pairs with no LCA specified emits nothing.""" - G = self.DG.copy() - G.add_node(-1) - gen = naive_all_pairs_lca(G, [(-1, -1), (-1, 0)]) - assert dict(gen) == {(-1, -1): -1} - def 
test_naive_lowest_common_ancestor1(self): """Test that the one-pair function works for issue #4574.""" G = nx.DiGraph() @@ -419,7 +337,7 @@ class TestNaiveLCA: ] ) - assert naive_lca(G, 7, 9) == None + assert nx.lowest_common_ancestor(G, 7, 9) == None def test_naive_lowest_common_ancestor2(self): """Test that the one-pair function works for issue #4942.""" @@ -430,4 +348,4 @@ class TestNaiveLCA: G.add_edge(4, 0) G.add_edge(5, 2) - assert naive_lca(G, 1, 3) == 2 + assert nx.lowest_common_ancestor(G, 1, 3) == 2 |