author     Jarrod Millman <jarrod.millman@gmail.com>   2021-09-08 16:25:45 -0700
committer  Jarrod Millman <jarrod.millman@gmail.com>   2021-09-09 15:02:03 -0700
commit     9a9f0d93a3aa57abbc8697c9dac9a395d1954ee9 (patch)
tree       b6a96e5343cd3f114529df52a5f14bc4e820837c
parent     36369bad70dbf5cc5b52dcd175d1a60b50d29e37 (diff)
download   networkx-9a9f0d93a3aa57abbc8697c9dac9a395d1954ee9.tar.gz
Fix modularity functions (#5072)
* CI/MAINT: drop gdal tests (#5068)
* Unpin gdal.
* Try pinning to 3.3
* Pin setuptools instead.
* remove gdal from workflows
* modularity_max: provide labels to get_edge_data (#4965)
* modularity_max: provide labels to get_edge_data
* test greedy mod communities relabeled separately
* Minor style changes + add note to test.
Co-authored-by: Mathilde Leval <mleval@csod.com>
Co-authored-by: Ross Barnowski <rossbar@berkeley.edu>
* Improvements to greedy_modularity_community (#4996)
* fix docstring importation at naive_greedy_modularity_communities
* add resolution at docstring parameters at modularity_max.py
* use weight arg instead of 'weight' key at greedy_modularity_communities()
* separate test for non-contiguous integers as node-labels
* modularity_max: breaking the loop when given community size is reached (#4950)
* modularity_max: allow input of desired number of communities
* import warnings
* format
* format
* improvements according to discussion
* try to manually merge main + resolve conflicts
* add test for n_communities parameter using circular ladder graph
* style of test
* greedy_modularity_communities with digraphs and multi(di)graphs (#5007)
* refactor N & m calculation @ greedy_modularity_communities()
* add Newman 'Analysis of weighted networks' @ References
* extend greedy_modularity_communities to DiGraph's
* separate data structures init to a new function
* remove unused 'merges' list
* add @not_implemented_for('directed', 'multigraph') above naive_greedy_modularity_communities()
* add tests for greedy_modularity_communities() with directed & directed+weighted
* use nx.all_neighbors() to access successors as well as predecessors at DiGraph's
* extend greedy_modularity_communities() to MultiGraph's
* extend greedy_modularity_communities() to MultiDiGraph's
* refactor: remove encoder/decoder dicts (node-labels are already hashable)
* b pulls data from in_degree instead of out_degree
* match the sequence of the return values with the docstring reference
* test: modify existing Graphs instead of creating new ones
* dq correction for multi-edges explanation & other minor edits
* CNM -> Clauset-Newman-Moore & isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)) -> G.is_multigraph()
* amend @not_implemented_for decorator @ naive_greedy_modularity_communities()
* Allow greedy_modularity_communities to use floating point weights or resolution (#5065)
* revise mapped_queue to separate priority from element
* update max_modularity to use new mapped_queue
* change attribute names h, d to heap, position in MappedQueue
* clean up initialization of data structures and handling of q0
* change i,j,k notation to u,v,w (no indexes since gh-5007)
* Update networkx/utils/mapped_queue.py
Co-authored-by: Ross Barnowski <rossbar@berkeley.edu>
Co-authored-by: Mathilde Léval <9384853+mathilde-leval@users.noreply.github.com>
Co-authored-by: Mathilde Leval <mleval@csod.com>
Co-authored-by: Thanasis Mattas <thanasismatt@gmail.com>
Co-authored-by: Martha Frysztacki <martha.frysztacki@kit.edu>
Co-authored-by: Thanasis Mattas <atmattas@physics.auth.gr>
Co-authored-by: Dan Schult <dschult@colgate.edu>
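For quick orientation before the diffs, a minimal usage sketch of the two headline behavior changes: greedy_modularity_communities now accepts directed (and multi) graphs, which the previous implementation did not (its docstring said it supported the Graph class only), and the merging loop can stop early at a desired community count. Both snippets mirror the new tests in this commit (test_greedy_modularity_communities_directed and test_n_communities_parameter); the outputs shown are the ones asserted there.

    import networkx as nx
    from networkx.algorithms.community import greedy_modularity_communities

    # Directed input (here with self-loops) is now handled.
    D = nx.DiGraph(
        [(1, 1), (1, 2), (1, 3), (2, 3), (1, 4), (4, 4), (5, 5), (4, 5), (4, 6), (5, 6)]
    )
    greedy_modularity_communities(D)
    # [frozenset({1, 2, 3}), frozenset({4, 5, 6})]

    # Stop merging once 4 communities remain instead of running to the optimum.
    G = nx.circular_ladder_graph(4)
    greedy_modularity_communities(G, n_communities=4)
    # [{0, 1}, {2, 3}, {4, 5}, {6, 7}]  (returned as frozensets, largest first)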
-rw-r--r--  .github/workflows/coverage.yml                                 5
-rw-r--r--  .github/workflows/deploy-docs.yml                              5
-rw-r--r--  networkx/algorithms/community/modularity_max.py              293
-rw-r--r--  networkx/algorithms/community/quality.py                       6
-rw-r--r--  networkx/algorithms/community/tests/test_modularity_max.py  248
-rw-r--r--  networkx/utils/mapped_queue.py                               284
-rw-r--r--  networkx/utils/tests/test_mapped_queue.py                    127
7 files changed, 694 insertions, 274 deletions
diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml
index cd9120d7..eb216dce 100644
--- a/.github/workflows/coverage.yml
+++ b/.github/workflows/coverage.yml
@@ -22,16 +22,13 @@ jobs:
- name: Before install
run: |
sudo apt-get update
- sudo apt-get install libgdal-dev graphviz graphviz-dev
+ sudo apt-get install graphviz graphviz-dev
- name: Install packages
run: |
pip install --upgrade pip wheel setuptools
pip install -r requirements/default.txt -r requirements/test.txt
pip install -r requirements/extra.txt
- export CPLUS_INCLUDE_PATH=/usr/include/gdal
- export C_INCLUDE_PATH=/usr/include/gdal
- pip install gdal==3.0.4
pip install .
pip list
diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml
index 20463943..746b73d5 100644
--- a/.github/workflows/deploy-docs.yml
+++ b/.github/workflows/deploy-docs.yml
@@ -22,7 +22,7 @@ jobs:
- name: Before install
run: |
sudo apt-get update
- sudo apt-get install libgdal-dev graphviz graphviz-dev
+ sudo apt-get install graphviz graphviz-dev
sudo apt-get install texlive texlive-latex-extra latexmk texlive-xetex
sudo apt-get install fonts-freefont-otf xindy
sudo apt-get install libspatialindex-dev
@@ -34,9 +34,6 @@ jobs:
pip install -r requirements/extra.txt
pip install -r requirements/example.txt
pip install -U -r requirements/doc.txt
- export CPLUS_INCLUDE_PATH=/usr/include/gdal
- export C_INCLUDE_PATH=/usr/include/gdal
- pip install gdal==3.0.4
pip install .
pip list
diff --git a/networkx/algorithms/community/modularity_max.py b/networkx/algorithms/community/modularity_max.py
index cc9d5d98..10d42e7f 100644
--- a/networkx/algorithms/community/modularity_max.py
+++ b/networkx/algorithms/community/modularity_max.py
@@ -1,8 +1,11 @@
"""Functions for detecting communities based on modularity."""
-from networkx.algorithms.community.quality import modularity
+from collections import defaultdict
+import networkx as nx
+from networkx.algorithms.community.quality import modularity
from networkx.utils.mapped_queue import MappedQueue
+from networkx.utils import not_implemented_for
__all__ = [
"greedy_modularity_communities",
@@ -11,15 +14,14 @@ __all__ = [
]
-def greedy_modularity_communities(G, weight=None, resolution=1):
+def greedy_modularity_communities(G, weight=None, resolution=1, n_communities=1):
r"""Find communities in G using greedy modularity maximization.
This function uses Clauset-Newman-Moore greedy modularity maximization [2]_.
- This method currently supports the Graph class.
Greedy modularity maximization begins with each node in its own community
and joins the pair of communities that most increases modularity until no
- such pair exists.
+ such pair exists or until the number of communities `n_communities` is reached.
This function maximizes the generalized modularity, where `resolution`
is the resolution parameter, often expressed as $\gamma$.
@@ -28,22 +30,35 @@ def greedy_modularity_communities(G, weight=None, resolution=1):
Parameters
----------
G : NetworkX graph
+
weight : string or None, optional (default=None)
- The name of an edge attribute that holds the numerical value used
- as a weight. If None, then each edge has weight 1.
- The degree is the sum of the edge weights adjacent to the node.
+ The name of an edge attribute that holds the numerical value used
+ as a weight. If None, then each edge has weight 1.
+ The degree is the sum of the edge weights adjacent to the node.
+
+ resolution : float (default=1)
+ If resolution is less than 1, modularity favors larger communities.
+ Greater than 1 favors smaller communities.
+
+ n_communities : int
+ Desired number of communities: the community merging process is
+ terminated once this number of communities is reached, or until
+ modularity can no longer be increased. Must be between 1 and the
+ total number of nodes in `G`. Default is ``1``, meaning the community
+ merging process continues until all nodes are in the same community
+ or until the best community structure is found.
Returns
-------
- list
- A list of sets of nodes, one for each community.
+ partition : list
+ A list of frozensets of nodes, one for each community.
Sorted by length with largest communities first.
Examples
--------
>>> from networkx.algorithms.community import greedy_modularity_communities
>>> G = nx.karate_club_graph()
- >>> c = list(greedy_modularity_communities(G))
+ >>> c = greedy_modularity_communities(G)
>>> sorted(c[0])
[8, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]
@@ -53,176 +68,177 @@ def greedy_modularity_communities(G, weight=None, resolution=1):
References
----------
- .. [1] M. E. J Newman "Networks: An Introduction", page 224
+ .. [1] Newman, M. E. J. "Networks: An Introduction", page 224
Oxford University Press 2011.
.. [2] Clauset, A., Newman, M. E., & Moore, C.
"Finding community structure in very large networks."
Physical Review E 70(6), 2004.
.. [3] Reichardt and Bornholdt "Statistical Mechanics of Community
Detection" Phys. Rev. E74, 2006.
+ .. [4] Newman, M. E. J. "Analysis of weighted networks"
+ Physical Review E 70(5 Pt 2):056131, 2004.
"""
-
- # Count nodes and edges
- N = len(G.nodes())
- m = sum([d.get("weight", 1) for u, v, d in G.edges(data=True)])
- q0 = 1.0 / (2.0 * m)
-
- # Map node labels to contiguous integers
- label_for_node = {i: v for i, v in enumerate(G.nodes())}
- node_for_label = {label_for_node[i]: i for i in range(N)}
-
- # Calculate degrees
- k_for_label = G.degree(G.nodes(), weight=weight)
- k = [k_for_label[label_for_node[i]] for i in range(N)]
-
- # Initialize community and merge lists
- communities = {i: frozenset([i]) for i in range(N)}
- merges = []
-
- # Initial modularity
- partition = [[label_for_node[x] for x in c] for c in communities.values()]
- q_cnm = modularity(G, partition, resolution=resolution)
-
- # Initialize data structures
- # CNM Eq 8-9 (Eq 8 was missing a factor of 2 (from A_ij + A_ji)
- # a[i]: fraction of edges within community i
- # dq_dict[i][j]: dQ for merging community i, j
- # dq_heap[i][n] : (-dq, i, j) for communitiy i nth largest dQ
- # H[n]: (-dq, i, j) for community with nth largest max_j(dQ_ij)
- a = [k[i] * q0 for i in range(N)]
- dq_dict = {
- i: {
- j: 2 * q0 * G.get_edge_data(i, j).get(weight, 1.0)
- - 2 * resolution * k[i] * k[j] * q0 * q0
- for j in [node_for_label[u] for u in G.neighbors(label_for_node[i])]
- if j != i
- }
- for i in range(N)
- }
- dq_heap = [
- MappedQueue([(-dq, i, j) for j, dq in dq_dict[i].items()]) for i in range(N)
- ]
- H = MappedQueue([dq_heap[i].h[0] for i in range(N) if len(dq_heap[i]) > 0])
-
- # Merge communities until we can't improve modularity
- while len(H) > 1:
+ directed = G.is_directed()
+ N = G.number_of_nodes()
+ if (n_communities < 1) or (n_communities > N):
+ raise ValueError(
+ f"n_communities must be between 1 and {N}. Got {n_communities}"
+ )
+
+ # Count edges (or the sum of edge-weights for weighted graphs)
+ m = G.size(weight)
+ q0 = 1 / m
+
+ # Calculate degrees (notation from the papers)
+ # a : the fraction of (weighted) out-degree for each node
+ # b : the fraction of (weighted) in-degree for each node
+ if directed:
+ a = {node: deg_out * q0 for node, deg_out in G.out_degree(weight=weight)}
+ b = {node: deg_in * q0 for node, deg_in in G.in_degree(weight=weight)}
+ else:
+ a = b = {node: deg * q0 * 0.5 for node, deg in G.degree(weight=weight)}
+
+ # This preliminary step collects the edge weights for each node pair.
+ # It handles multigraphs and digraphs, and works for plain graphs as well.
+ dq_dict = defaultdict(lambda: defaultdict(float))
+ for u, v, wt in G.edges(data=weight, default=1):
+ if u == v:
+ continue
+ dq_dict[u][v] += wt
+ dq_dict[v][u] += wt
+
+ # now scale and subtract the expected edge-weights term
+ for u, nbrdict in dq_dict.items():
+ for v, wt in nbrdict.items():
+ dq_dict[u][v] = q0 * wt - resolution * (a[u] * b[v] + b[u] * a[v])
+
+ # Use -dq to get a max_heap instead of a min_heap
+ # dq_heap holds a heap for each node's neighbors
+ dq_heap = {u: MappedQueue({(u, v): -dq for v, dq in dq_dict[u].items()}) for u in G}
+ # H -> all_dq_heap holds a heap with the best items for each node
+ H = MappedQueue([dq_heap[n].heap[0] for n in G if len(dq_heap[n]) > 0])
+
+ # Initialize single-node communities
+ communities = {n: frozenset([n]) for n in G}
+
+ # Merge communities until we can't improve modularity or until desired number of
+ # communities (n_communities) is reached.
+ while len(H) > n_communities:
# Find best merge
# Remove from heap of row maxes
# Ties will be broken by choosing the pair with lowest min community id
try:
- dq, i, j = H.pop()
+ negdq, u, v = H.pop()
except IndexError:
break
- dq = -dq
- # Remove best merge from row i heap
- dq_heap[i].pop()
+ dq = -negdq
+ # Remove best merge from row u heap
+ dq_heap[u].pop()
# Push new row max onto H
- if len(dq_heap[i]) > 0:
- H.push(dq_heap[i].h[0])
- # If this element was also at the root of row j, we need to remove the
+ if len(dq_heap[u]) > 0:
+ H.push(dq_heap[u].heap[0])
+ # If this element was also at the root of row v, we need to remove the
# duplicate entry from H
- if dq_heap[j].h[0] == (-dq, j, i):
- H.remove((-dq, j, i))
- # Remove best merge from row j heap
- dq_heap[j].remove((-dq, j, i))
+ if dq_heap[v].heap[0] == (v, u):
+ H.remove((v, u))
+ # Remove best merge from row v heap
+ dq_heap[v].remove((v, u))
# Push new row max onto H
- if len(dq_heap[j]) > 0:
- H.push(dq_heap[j].h[0])
+ if len(dq_heap[v]) > 0:
+ H.push(dq_heap[v].heap[0])
else:
- # Duplicate wasn't in H, just remove from row j heap
- dq_heap[j].remove((-dq, j, i))
- # Stop when change is non-positive
+ # Duplicate wasn't in H, just remove from row v heap
+ dq_heap[v].remove((v, u))
+ # Stop when change is non-positive (no improvement possible)
if dq <= 0:
break
# Perform merge
- communities[j] = frozenset(communities[i] | communities[j])
- del communities[i]
- merges.append((i, j, dq))
- # New modularity
- q_cnm += dq
- # Get list of communities connected to merged communities
- i_set = set(dq_dict[i].keys())
- j_set = set(dq_dict[j].keys())
- all_set = (i_set | j_set) - {i, j}
- both_set = i_set & j_set
- # Merge i into j and update dQ
- for k in all_set:
+ communities[v] = frozenset(communities[u] | communities[v])
+ del communities[u]
+
+ # Get neighbor communities connected to the merged communities
+ u_nbrs = set(dq_dict[u])
+ v_nbrs = set(dq_dict[v])
+ all_nbrs = (u_nbrs | v_nbrs) - {u, v}
+ both_nbrs = u_nbrs & v_nbrs
+ # Update dq for merge of u into v
+ for w in all_nbrs:
# Calculate new dq value
- if k in both_set:
- dq_jk = dq_dict[j][k] + dq_dict[i][k]
- elif k in j_set:
- dq_jk = dq_dict[j][k] - 2.0 * resolution * a[i] * a[k]
- else:
- # k in i_set
- dq_jk = dq_dict[i][k] - 2.0 * resolution * a[j] * a[k]
- # Update rows j and k
- for row, col in [(j, k), (k, j)]:
- # Save old value for finding heap index
- if k in j_set:
- d_old = (-dq_dict[row][col], row, col)
- else:
- d_old = None
- # Update dict for j,k only (i is removed below)
- dq_dict[row][col] = dq_jk
+ if w in both_nbrs:
+ dq_vw = dq_dict[v][w] + dq_dict[u][w]
+ elif w in v_nbrs:
+ dq_vw = dq_dict[v][w] - resolution * (a[u] * b[w] + a[w] * b[u])
+ else: # w in u_nbrs
+ dq_vw = dq_dict[u][w] - resolution * (a[v] * b[w] + a[w] * b[v])
+ # Update rows v and w
+ for row, col in [(v, w), (w, v)]:
+ dq_heap_row = dq_heap[row]
+ # Update dict for v,w only (u is removed below)
+ dq_dict[row][col] = dq_vw
# Save old max of per-row heap
- if len(dq_heap[row]) > 0:
- d_oldmax = dq_heap[row].h[0]
+ if len(dq_heap_row) > 0:
+ d_oldmax = dq_heap_row.heap[0]
else:
d_oldmax = None
# Add/update heaps
- d = (-dq_jk, row, col)
- if d_old is None:
- # We're creating a new nonzero element, add to heap
- dq_heap[row].push(d)
- else:
+ d = (row, col)
+ d_negdq = -dq_vw
+ # Save old value for finding heap index
+ if w in v_nbrs:
# Update existing element in per-row heap
- dq_heap[row].update(d_old, d)
+ dq_heap_row.update(d, d, priority=d_negdq)
+ else:
+ # We're creating a new nonzero element, add to heap
+ dq_heap_row.push(d, priority=d_negdq)
# Update heap of row maxes if necessary
if d_oldmax is None:
# No entries previously in this row, push new max
- H.push(d)
+ H.push(d, priority=d_negdq)
else:
# We've updated an entry in this row, has the max changed?
- if dq_heap[row].h[0] != d_oldmax:
- H.update(d_oldmax, dq_heap[row].h[0])
+ row_max = dq_heap_row.heap[0]
+ if d_oldmax != row_max or d_oldmax.priority != row_max.priority:
+ H.update(d_oldmax, row_max)
- # Remove row/col i from matrix
- i_neighbors = dq_dict[i].keys()
- for k in i_neighbors:
+ # Remove row/col u from dq_dict matrix
+ for w in dq_dict[u]:
# Remove from dict
- dq_old = dq_dict[k][i]
- del dq_dict[k][i]
+ dq_old = dq_dict[w][u]
+ del dq_dict[w][u]
# Remove from heaps if we haven't already
- if k != j:
+ if w != v:
# Remove both row and column
- for row, col in [(k, i), (i, k)]:
+ for row, col in [(w, u), (u, w)]:
+ dq_heap_row = dq_heap[row]
# Check if replaced dq is row max
- d_old = (-dq_old, row, col)
- if dq_heap[row].h[0] == d_old:
+ d_old = (row, col)
+ if dq_heap_row.heap[0] == d_old:
# Update per-row heap and heap of row maxes
- dq_heap[row].remove(d_old)
+ dq_heap_row.remove(d_old)
H.remove(d_old)
# Update row max
- if len(dq_heap[row]) > 0:
- H.push(dq_heap[row].h[0])
+ if len(dq_heap_row) > 0:
+ H.push(dq_heap_row.heap[0])
else:
# Only update per-row heap
- dq_heap[row].remove(d_old)
+ dq_heap_row.remove(d_old)
- del dq_dict[i]
- # Mark row i as deleted, but keep placeholder
- dq_heap[i] = MappedQueue()
- # Merge i into j and update a
- a[j] += a[i]
- a[i] = 0
+ del dq_dict[u]
+ # Mark row u as deleted, but keep placeholder
+ dq_heap[u] = MappedQueue()
+ # Merge u into v and update a
+ a[v] += a[u]
+ a[u] = 0
+ if directed:
+ b[v] += b[u]
+ b[u] = 0
- communities = [
- frozenset([label_for_node[i] for i in c]) for c in communities.values()
- ]
- return sorted(communities, key=len, reverse=True)
+ return sorted(communities.values(), key=len, reverse=True)
+@not_implemented_for("directed")
+@not_implemented_for("multigraph")
def naive_greedy_modularity_communities(G, resolution=1):
r"""Find communities in G using greedy modularity maximization.
@@ -241,6 +257,10 @@ def naive_greedy_modularity_communities(G, resolution=1):
----------
G : NetworkX graph
+ resolution : float (default=1)
+ If resolution is less than 1, modularity favors larger communities.
+ Greater than 1 favors smaller communities.
+
Returns
-------
list
@@ -249,9 +269,10 @@ def naive_greedy_modularity_communities(G, resolution=1):
Examples
--------
- >>> from networkx.algorithms.community import greedy_modularity_communities
+ >>> from networkx.algorithms.community import \
+ ... naive_greedy_modularity_communities
>>> G = nx.karate_club_graph()
- >>> c = list(greedy_modularity_communities(G))
+ >>> c = naive_greedy_modularity_communities(G)
>>> sorted(c[0])
[8, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]
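Reading the initialization above off the diff: for two singleton communities u and v, the value kept in dq_dict[u][v] is the modularity gain of merging them, where gamma is the resolution and m the total edge weight,

    \Delta Q_{uv} = \frac{w_{uv} + w_{vu}}{m}
                  - \gamma \left( a_u b_v + a_v b_u \right),
    \qquad a_u = \frac{k_u^{\text{out}}}{m}, \quad b_u = \frac{k_u^{\text{in}}}{m}.

In the undirected case a_u = b_u = k_u / (2m) and each edge is counted once, which recovers the familiar \Delta Q_{uv} = w_{uv}/m - \gamma k_u k_v / (2m^2). A quick sanity check on a single unweighted edge (m = 1, k_u = k_v = 1, gamma = 1): \Delta Q = 1 - 1/2 = 1/2, taking the two-singleton partition from Q = -1/2 to Q = 0.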
diff --git a/networkx/algorithms/community/quality.py b/networkx/algorithms/community/quality.py
index 9f374435..ce79727b 100644
--- a/networkx/algorithms/community/quality.py
+++ b/networkx/algorithms/community/quality.py
@@ -283,9 +283,9 @@ def modularity(G, communities, weight="weight", resolution=1):
These node sets must represent a partition of G's nodes.
weight : string or None, optional (default="weight")
- The edge attribute that holds the numerical value used
- as a weight. If None or an edge does not have that attribute,
- then that edge has weight 1.
+ The edge attribute that holds the numerical value used
+ as a weight. If None or an edge does not have that attribute,
+ then that edge has weight 1.
resolution : float (default=1)
If resolution is less than 1, modularity favors larger communities.
diff --git a/networkx/algorithms/community/tests/test_modularity_max.py b/networkx/algorithms/community/tests/test_modularity_max.py
index 23103d4b..433ca746 100644
--- a/networkx/algorithms/community/tests/test_modularity_max.py
+++ b/networkx/algorithms/community/tests/test_modularity_max.py
@@ -3,7 +3,6 @@ import pytest
import networkx as nx
from networkx.algorithms.community import (
greedy_modularity_communities,
- modularity,
naive_greedy_modularity_communities,
)
@@ -13,17 +12,74 @@ from networkx.algorithms.community import (
)
def test_modularity_communities(func):
G = nx.karate_club_graph()
-
john_a = frozenset(
[8, 14, 15, 18, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]
)
mr_hi = frozenset([0, 4, 5, 6, 10, 11, 16, 19])
overlap = frozenset([1, 2, 3, 7, 9, 12, 13, 17, 21])
expected = {john_a, overlap, mr_hi}
+ assert set(func(G)) == expected
+
+@pytest.mark.parametrize(
+ "func", (greedy_modularity_communities, naive_greedy_modularity_communities)
+)
+def test_modularity_communities_categorical_labels(func):
+ # Using other than 0-starting contiguous integers as node-labels.
+ G = nx.Graph(
+ [
+ ("a", "b"),
+ ("a", "c"),
+ ("b", "c"),
+ ("b", "d"), # inter-community edge
+ ("d", "e"),
+ ("d", "f"),
+ ("d", "g"),
+ ("f", "g"),
+ ("d", "e"),
+ ("f", "e"),
+ ]
+ )
+ expected = {frozenset({"f", "g", "e", "d"}), frozenset({"a", "b", "c"})}
assert set(func(G)) == expected
+def test_greedy_modularity_communities_relabeled():
+ # Test for gh-4966
+ G = nx.balanced_tree(2, 2)
+ mapping = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f", 6: "g", 7: "h"}
+ G = nx.relabel_nodes(G, mapping)
+ expected = [frozenset({"e", "d", "a", "b"}), frozenset({"c", "f", "g"})]
+ assert greedy_modularity_communities(G) == expected
+
+
+def test_greedy_modularity_communities_directed():
+ G = nx.DiGraph(
+ [
+ ("a", "b"),
+ ("a", "c"),
+ ("b", "c"),
+ ("b", "d"), # inter-community edge
+ ("d", "e"),
+ ("d", "f"),
+ ("d", "g"),
+ ("f", "g"),
+ ("d", "e"),
+ ("f", "e"),
+ ]
+ )
+ expected = [frozenset({"f", "g", "e", "d"}), frozenset({"a", "b", "c"})]
+ assert greedy_modularity_communities(G) == expected
+
+ # with loops
+ G = nx.DiGraph()
+ G.add_edges_from(
+ [(1, 1), (1, 2), (1, 3), (2, 3), (1, 4), (4, 4), (5, 5), (4, 5), (4, 6), (5, 6)]
+ )
+ expected = [frozenset({1, 2, 3}), frozenset({4, 5, 6})]
+ assert greedy_modularity_communities(G) == expected
+
+
def test_modularity_communities_weighted():
G = nx.balanced_tree(2, 3)
for (a, b) in G.edges:
@@ -35,6 +91,178 @@ def test_modularity_communities_weighted():
expected = [{0, 1, 3, 4, 7, 8, 9, 10}, {2, 5, 6, 11, 12, 13, 14}]
assert greedy_modularity_communities(G, weight="weight") == expected
+ assert greedy_modularity_communities(G, weight="weight", resolution=0.9) == expected
+ assert greedy_modularity_communities(G, weight="weight", resolution=0.3) == expected
+ assert greedy_modularity_communities(G, weight="weight", resolution=1.1) != expected
+
+
+def test_modularity_communities_floating_point():
+ # check for floating point error when used as key in the mapped_queue dict.
+ # Test for gh-4992 and gh-5000
+ G = nx.Graph()
+ G.add_weighted_edges_from(
+ [(0, 1, 12), (1, 4, 71), (2, 3, 15), (2, 4, 10), (3, 6, 13)]
+ )
+ expected = [{0, 1, 4}, {2, 3, 6}]
+ assert greedy_modularity_communities(G, weight="weight") == expected
+ assert (
+ greedy_modularity_communities(G, weight="weight", resolution=0.99) == expected
+ )
+
+
+def test_modularity_communities_directed_weighted():
+ G = nx.DiGraph()
+ G.add_weighted_edges_from(
+ [
+ (1, 2, 5),
+ (1, 3, 3),
+ (2, 3, 6),
+ (2, 6, 1),
+ (1, 4, 1),
+ (4, 5, 3),
+ (4, 6, 7),
+ (5, 6, 2),
+ (5, 7, 5),
+ (5, 8, 4),
+ (6, 8, 3),
+ ]
+ )
+ expected = [frozenset({4, 5, 6, 7, 8}), frozenset({1, 2, 3})]
+ assert greedy_modularity_communities(G, weight="weight") == expected
+
+ # A large weight on edge (2, 6) causes 6 to change groups, even though it shares
+ # only one connection with the new group and 3 with the old one.
+ G[2][6]["weight"] = 20
+ expected = [frozenset({1, 2, 3, 6}), frozenset({4, 5, 7, 8})]
+ assert greedy_modularity_communities(G, weight="weight") == expected
+
+
+def test_greedy_modularity_communities_multigraph():
+ G = nx.MultiGraph()
+ G.add_edges_from(
+ [
+ (1, 2),
+ (1, 2),
+ (1, 3),
+ (2, 3),
+ (1, 4),
+ (2, 4),
+ (4, 5),
+ (5, 6),
+ (5, 7),
+ (5, 7),
+ (6, 7),
+ (7, 8),
+ (5, 8),
+ ]
+ )
+ expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})]
+ assert greedy_modularity_communities(G) == expected
+
+ # Converting (4, 5) into a multi-edge causes node 4 to change group.
+ G.add_edge(4, 5)
+ expected = [frozenset({4, 5, 6, 7, 8}), frozenset({1, 2, 3})]
+ assert greedy_modularity_communities(G) == expected
+
+
+def test_greedy_modularity_communities_multigraph_weighted():
+ G = nx.MultiGraph()
+ G.add_weighted_edges_from(
+ [
+ (1, 2, 5),
+ (1, 2, 3),
+ (1, 3, 6),
+ (1, 3, 6),
+ (2, 3, 4),
+ (1, 4, 1),
+ (1, 4, 1),
+ (2, 4, 3),
+ (2, 4, 3),
+ (4, 5, 1),
+ (5, 6, 3),
+ (5, 6, 7),
+ (5, 6, 4),
+ (5, 7, 9),
+ (5, 7, 9),
+ (6, 7, 8),
+ (7, 8, 2),
+ (7, 8, 2),
+ (5, 8, 6),
+ (5, 8, 6),
+ ]
+ )
+ expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})]
+ assert greedy_modularity_communities(G, weight="weight") == expected
+
+ # Adding multi-edge (4, 5, 16) causes node 4 to change group.
+ G.add_edge(4, 5, weight=16)
+ expected = [frozenset({4, 5, 6, 7, 8}), frozenset({1, 2, 3})]
+ assert greedy_modularity_communities(G, weight="weight") == expected
+
+ # Increasing the weight of edge (1, 4) causes node 4 to return to the former group.
+ G[1][4][1]["weight"] = 3
+ expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})]
+ assert greedy_modularity_communities(G, weight="weight") == expected
+
+
+def test_greedy_modularity_communities_multidigraph():
+ G = nx.MultiDiGraph()
+ G.add_edges_from(
+ [
+ (1, 2),
+ (1, 2),
+ (3, 1),
+ (2, 3),
+ (2, 3),
+ (3, 2),
+ (1, 4),
+ (2, 4),
+ (4, 2),
+ (4, 5),
+ (5, 6),
+ (5, 6),
+ (6, 5),
+ (5, 7),
+ (6, 7),
+ (7, 8),
+ (5, 8),
+ (8, 4),
+ ]
+ )
+ expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})]
+ assert greedy_modularity_communities(G, weight="weight") == expected
+
+
+def test_greedy_modularity_communities_multidigraph_weighted():
+ G = nx.MultiDiGraph()
+ G.add_weighted_edges_from(
+ [
+ (1, 2, 5),
+ (1, 2, 3),
+ (3, 1, 6),
+ (1, 3, 6),
+ (3, 2, 4),
+ (1, 4, 2),
+ (1, 4, 5),
+ (2, 4, 3),
+ (3, 2, 8),
+ (4, 2, 3),
+ (4, 3, 5),
+ (4, 5, 2),
+ (5, 6, 3),
+ (5, 6, 7),
+ (6, 5, 4),
+ (5, 7, 9),
+ (5, 7, 9),
+ (7, 6, 8),
+ (7, 8, 2),
+ (8, 7, 2),
+ (5, 8, 6),
+ (5, 8, 6),
+ ]
+ )
+ expected = [frozenset({1, 2, 3, 4}), frozenset({5, 6, 7, 8})]
+ assert greedy_modularity_communities(G, weight="weight") == expected
def test_resolution_parameter_impact():
@@ -54,3 +282,19 @@ def test_resolution_parameter_impact():
expected = [frozenset(range(8)), frozenset(range(8, 13))]
assert greedy_modularity_communities(G, resolution=gamma) == expected
assert naive_greedy_modularity_communities(G, resolution=gamma) == expected
+
+
+def test_n_communities_parameter():
+ G = nx.circular_ladder_graph(4)
+
+ # No aggregation:
+ expected = [{k} for k in range(8)]
+ assert greedy_modularity_communities(G, n_communities=8) == expected
+
+ # Aggregation to half order (number of nodes)
+ expected = [{k, k + 1} for k in range(0, 8, 2)]
+ assert greedy_modularity_communities(G, n_communities=4) == expected
+
+ # Default aggregation case (here, 2 communities emerge)
+ expected = [frozenset(range(0, 4)), frozenset(range(4, 8))]
+ assert greedy_modularity_communities(G, n_communities=1) == expected
diff --git a/networkx/utils/mapped_queue.py b/networkx/utils/mapped_queue.py
index 5888348e..0ff53a0b 100644
--- a/networkx/utils/mapped_queue.py
+++ b/networkx/utils/mapped_queue.py
@@ -6,16 +6,92 @@ import heapq
__all__ = ["MappedQueue"]
+class _HeapElement:
+ """This proxy class separates the heap element from its priority.
+
+ The idea is that using a 2-tuple (priority, element) works
+ for sorting, but not for dict lookup because priorities are
+ often floating point values so round-off can mess up equality.
+
+ So, we need inequalities to look at the priority (for sorting)
+ and equality (and hash) to look at the element to enable
+ updates to the priority.
+
+ Unfortunately, this class can be tricky to work with if you forget that
+ `__lt__` compares the priority while `__eq__` compares the element.
+ In `greedy_modularity_communities()` the following code is
+ used to check that two _HeapElements differ in either element or priority:
+
+ if d_oldmax != row_max or d_oldmax.priority != row_max.priority:
+
+ If the priorities are the same, this implementation uses the element
+ as a tiebreaker. This provides compatibility with older systems that
+ use tuples to combine priority and elements.
+ """
+
+ __slots__ = ["priority", "element", "_hash"]
+
+ def __init__(self, priority, element):
+ self.priority = priority
+ self.element = element
+ self._hash = hash(element)
+
+ def __lt__(self, other):
+ try:
+ other_priority = other.priority
+ except AttributeError:
+ return self.priority < other
+ # assume comparing to another _HeapElement
+ if self.priority == other_priority:
+ return self.element < other.element
+ return self.priority < other_priority
+
+ def __gt__(self, other):
+ try:
+ other_priority = other.priority
+ except AttributeError:
+ return self.priority > other
+ # assume comparing to another _HeapElement
+ if self.priority == other_priority:
+ return self.element > other.element
+ return self.priority > other_priority
+
+ def __eq__(self, other):
+ try:
+ return self.element == other.element
+ except AttributeError:
+ return self.element == other
+
+ def __hash__(self):
+ return self._hash
+
+ def __getitem__(self, indx):
+ return self.priority if indx == 0 else self.element[indx - 1]
+
+ def __iter__(self):
+ yield self.priority
+ try:
+ yield from self.element
+ except TypeError:
+ yield self.element
+
+ def __repr__(self):
+ return f"_HeapElement({self.priority}, {self.element})"
+
+
class MappedQueue:
- """The MappedQueue class implements an efficient minimum heap. The
- smallest element can be popped in O(1) time, new elements can be pushed
- in O(log n) time, and any element can be removed or updated in O(log n)
- time. The queue cannot contain duplicate elements and an attempt to push an
- element already in the queue will have no effect.
+ """The MappedQueue class implements a min-heap with removal and update-priority.
+
+ The min heap uses heapq as well as custom written _siftup and _siftdown
+ methods to allow the heap positions to be tracked by an additional dict
+ keyed by element to position. The smallest element can be popped in O(1) time,
+ new elements can be pushed in O(log n) time, and any element can be removed
+ or updated in O(log n) time. The queue cannot contain duplicate elements
+ and an attempt to push an element already in the queue will have no effect.
MappedQueue complements the heapq package from the python standard
library. While MappedQueue is designed for maximum compatibility with
- heapq, it has slightly different functionality.
+ heapq, it adds element removal, lookup, and priority update.
Examples
--------
@@ -27,8 +103,7 @@ class MappedQueue:
>>> q = MappedQueue([916, 50, 4609, 493, 237])
>>> q.push(1310)
True
- >>> x = [q.pop() for i in range(len(q.h))]
- >>> x
+ >>> [q.pop() for i in range(len(q.heap))]
[50, 237, 493, 916, 1310, 4609]
Elements can also be updated or removed from anywhere in the queue.
@@ -36,8 +111,7 @@ class MappedQueue:
>>> q = MappedQueue([916, 50, 4609, 493, 237])
>>> q.remove(493)
>>> q.update(237, 1117)
- >>> x = [q.pop() for i in range(len(q.h))]
- >>> x
+ >>> [q.pop() for i in range(len(q.heap))]
[50, 916, 1117, 4609]
References
@@ -50,132 +124,144 @@ class MappedQueue:
def __init__(self, data=[]):
"""Priority queue class with updatable priorities."""
- self.h = list(data)
- self.d = dict()
+ if isinstance(data, dict):
+ self.heap = [_HeapElement(v, k) for k, v in data.items()]
+ else:
+ self.heap = list(data)
+ self.position = dict()
self._heapify()
- def __len__(self):
- return len(self.h)
-
def _heapify(self):
"""Restore heap invariant and recalculate map."""
- heapq.heapify(self.h)
- self.d = {elt: pos for pos, elt in enumerate(self.h)}
- if len(self.h) != len(self.d):
+ heapq.heapify(self.heap)
+ self.position = {elt: pos for pos, elt in enumerate(self.heap)}
+ if len(self.heap) != len(self.position):
raise AssertionError("Heap contains duplicate elements")
- def push(self, elt):
+ def __len__(self):
+ return len(self.heap)
+
+ def push(self, elt, priority=None):
"""Add an element to the queue."""
+ if priority is not None:
+ elt = _HeapElement(priority, elt)
# If element is already in queue, do nothing
- if elt in self.d:
+ if elt in self.position:
return False
# Add element to heap and dict
- pos = len(self.h)
- self.h.append(elt)
- self.d[elt] = pos
+ pos = len(self.heap)
+ self.heap.append(elt)
+ self.position[elt] = pos
# Restore invariant by sifting down
- self._siftdown(pos)
+ self._siftdown(0, pos)
return True
def pop(self):
"""Remove and return the smallest element in the queue."""
# Remove smallest element
- elt = self.h[0]
- del self.d[elt]
+ elt = self.heap[0]
+ del self.position[elt]
# If elt is last item, remove and return
- if len(self.h) == 1:
- self.h.pop()
+ if len(self.heap) == 1:
+ self.heap.pop()
return elt
# Replace root with last element
- last = self.h.pop()
- self.h[0] = last
- self.d[last] = 0
- # Restore invariant by sifting up, then down
- pos = self._siftup(0)
- self._siftdown(pos)
+ last = self.heap.pop()
+ self.heap[0] = last
+ self.position[last] = 0
+ # Restore invariant by sifting up
+ self._siftup(0)
# Return smallest element
return elt
- def update(self, elt, new):
+ def update(self, elt, new, priority=None):
"""Replace an element in the queue with a new one."""
+ if priority is not None:
+ new = _HeapElement(priority, new)
# Replace
- pos = self.d[elt]
- self.h[pos] = new
- del self.d[elt]
- self.d[new] = pos
- # Restore invariant by sifting up, then down
- pos = self._siftup(pos)
- self._siftdown(pos)
+ pos = self.position[elt]
+ self.heap[pos] = new
+ del self.position[elt]
+ self.position[new] = pos
+ # Restore invariant by sifting up
+ self._siftup(pos)
def remove(self, elt):
"""Remove an element from the queue."""
# Find and remove element
try:
- pos = self.d[elt]
- del self.d[elt]
+ pos = self.position[elt]
+ del self.position[elt]
except KeyError:
# Not in queue
raise
# If elt is last item, remove and return
- if pos == len(self.h) - 1:
- self.h.pop()
+ if pos == len(self.heap) - 1:
+ self.heap.pop()
return
# Replace elt with last element
- last = self.h.pop()
- self.h[pos] = last
- self.d[last] = pos
- # Restore invariant by sifting up, then down
- pos = self._siftup(pos)
- self._siftdown(pos)
+ last = self.heap.pop()
+ self.heap[pos] = last
+ self.position[last] = pos
+ # Restore invariant by sifting up
+ self._siftup(pos)
def _siftup(self, pos):
- """Move element at pos down to a leaf by repeatedly moving the smaller
- child up."""
- h, d = self.h, self.d
- elt = h[pos]
- # Continue until element is in a leaf
- end_pos = len(h)
- left_pos = (pos << 1) + 1
- while left_pos < end_pos:
- # Left child is guaranteed to exist by loop predicate
- left = h[left_pos]
- try:
- right_pos = left_pos + 1
- right = h[right_pos]
- # Out-of-place, swap with left unless right is smaller
- if right < left:
- h[pos], h[right_pos] = right, elt
- pos, right_pos = right_pos, pos
- d[elt], d[right] = pos, right_pos
- else:
- h[pos], h[left_pos] = left, elt
- pos, left_pos = left_pos, pos
- d[elt], d[left] = pos, left_pos
- except IndexError:
- # Left leaf is the end of the heap, swap
- h[pos], h[left_pos] = left, elt
- pos, left_pos = left_pos, pos
- d[elt], d[left] = pos, left_pos
- # Update left_pos
- left_pos = (pos << 1) + 1
- return pos
-
- def _siftdown(self, pos):
- """Restore invariant by repeatedly replacing out-of-place element with
- its parent."""
- h, d = self.h, self.d
- elt = h[pos]
- # Continue until element is at root
+ """Move smaller child up until hitting a leaf.
+
+ Built to mimic code for heapq._siftup
+ only updating position dict too.
+ """
+ heap, position = self.heap, self.position
+ end_pos = len(heap)
+ startpos = pos
+ newitem = heap[pos]
+ # Shift up the smaller child until hitting a leaf
+ child_pos = (pos << 1) + 1 # start with leftmost child position
+ while child_pos < end_pos:
+ # Set child_pos to index of smaller child.
+ child = heap[child_pos]
+ right_pos = child_pos + 1
+ if right_pos < end_pos:
+ right = heap[right_pos]
+ if not child < right:
+ child = right
+ child_pos = right_pos
+ # Move the smaller child up.
+ heap[pos] = child
+ position[child] = pos
+ pos = child_pos
+ child_pos = (pos << 1) + 1
+ # pos is a leaf position. Put newitem there, and bubble it up
+ # to its final resting place (by sifting its parents down).
while pos > 0:
parent_pos = (pos - 1) >> 1
- parent = h[parent_pos]
- if parent > elt:
- # Swap out-of-place element with parent
- h[parent_pos], h[pos] = elt, parent
- parent_pos, pos = pos, parent_pos
- d[elt] = pos
- d[parent] = parent_pos
- else:
- # Invariant is satisfied
+ parent = heap[parent_pos]
+ if not newitem < parent:
+ break
+ heap[pos] = parent
+ position[parent] = pos
+ pos = parent_pos
+ heap[pos] = newitem
+ position[newitem] = pos
+
+ def _siftdown(self, start_pos, pos):
+ """Restore invariant. keep swapping with parent until smaller.
+
+ Built to mimic code for heapq._siftdown
+ only updating position dict too.
+ """
+ heap, position = self.heap, self.position
+ newitem = heap[pos]
+ # Follow the path to the root, moving parents down until finding a place
+ # newitem fits.
+ while pos > start_pos:
+ parent_pos = (pos - 1) >> 1
+ parent = heap[parent_pos]
+ if not newitem < parent:
break
- return pos
+ heap[pos] = parent
+ position[parent] = pos
+ pos = parent_pos
+ heap[pos] = newitem
+ position[newitem] = pos
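A short sketch of the extended MappedQueue API introduced above, with priorities supplied separately from elements (the element names and priority values here are illustrative only):

    from networkx.utils.mapped_queue import MappedQueue

    q = MappedQueue({"a": 3, "b": 1, "c": 2})  # dict form: element -> priority
    q.push("d", priority=0)                    # explicit priority at push time
    q.update("a", "a", priority=-1)            # reprioritize an existing element
    top = q.pop()                              # returns the _HeapElement proxy
    assert top.priority == -1 and top.element == "a"

Without the priority keyword (or with a list instead of a dict at construction), elements are compared directly, preserving the older tuple-based usage shown in the class docstring.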
diff --git a/networkx/utils/tests/test_mapped_queue.py b/networkx/utils/tests/test_mapped_queue.py
index 78ea91ec..89e251d4 100644
--- a/networkx/utils/tests/test_mapped_queue.py
+++ b/networkx/utils/tests/test_mapped_queue.py
@@ -1,4 +1,41 @@
-from networkx.utils.mapped_queue import MappedQueue
+import pytest
+from networkx.utils.mapped_queue import _HeapElement, MappedQueue
+
+
+def test_HeapElement_gtlt():
+ bar = _HeapElement(1.1, "a")
+ foo = _HeapElement(1, "b")
+ assert foo < bar
+ assert bar > foo
+ assert foo < 1.1
+ assert 1 < bar
+
+
+def test_HeapElement_eq():
+ bar = _HeapElement(1.1, "a")
+ foo = _HeapElement(1, "a")
+ assert foo == bar
+ assert bar == foo
+ assert foo == "a"
+
+
+def test_HeapElement_iter():
+ foo = _HeapElement(1, "a")
+ bar = _HeapElement(1.1, (3, 2, 1))
+ assert list(foo) == [1, "a"]
+ assert list(bar) == [1.1, 3, 2, 1]
+
+
+def test_HeapElement_getitem():
+ foo = _HeapElement(1, "a")
+ bar = _HeapElement(1.1, (3, 2, 1))
+ assert foo[1] == "a"
+ assert foo[0] == 1
+ assert bar[0] == 1.1
+ assert bar[2] == 2
+ assert bar[3] == 1
+ pytest.raises(IndexError, bar.__getitem__, 4)
+ pytest.raises(IndexError, foo.__getitem__, 2)
class TestMappedQueue:
@@ -6,13 +43,12 @@ class TestMappedQueue:
pass
def _check_map(self, q):
- d = {elt: pos for pos, elt in enumerate(q.h)}
- assert d == q.d
+ assert q.position == {elt: pos for pos, elt in enumerate(q.heap)}
def _make_mapped_queue(self, h):
q = MappedQueue()
- q.h = h
- q.d = {elt: pos for pos, elt in enumerate(h)}
+ q.heap = h
+ q.position = {elt: pos for pos, elt in enumerate(h)}
return q
def test_heapify(self):
@@ -37,7 +73,7 @@ class TestMappedQueue:
h_sifted = [2]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftup_one_child(self):
@@ -45,7 +81,7 @@ class TestMappedQueue:
h_sifted = [0, 2]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftup_left_child(self):
@@ -53,7 +89,7 @@ class TestMappedQueue:
h_sifted = [0, 2, 1]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftup_right_child(self):
@@ -61,39 +97,39 @@ class TestMappedQueue:
h_sifted = [0, 1, 2]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftup_multiple(self):
h = [0, 1, 2, 4, 3, 5, 6]
- h_sifted = [1, 3, 2, 4, 0, 5, 6]
+ h_sifted = [0, 1, 2, 4, 3, 5, 6]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftdown_leaf(self):
h = [2]
h_sifted = [2]
q = self._make_mapped_queue(h)
- q._siftdown(0)
- assert q.h == h_sifted
+ q._siftdown(0, 0)
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftdown_single(self):
h = [1, 0]
h_sifted = [0, 1]
q = self._make_mapped_queue(h)
- q._siftdown(len(h) - 1)
- assert q.h == h_sifted
+ q._siftdown(0, len(h) - 1)
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftdown_multiple(self):
h = [1, 2, 3, 4, 5, 6, 7, 0]
h_sifted = [0, 1, 3, 2, 5, 6, 7, 4]
q = self._make_mapped_queue(h)
- q._siftdown(len(h) - 1)
- assert q.h == h_sifted
+ q._siftdown(0, len(h) - 1)
+ assert q.heap == h_sifted
self._check_map(q)
def test_push(self):
@@ -102,7 +138,7 @@ class TestMappedQueue:
q = MappedQueue()
for elt in to_push:
q.push(elt)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_push_duplicate(self):
@@ -112,7 +148,7 @@ class TestMappedQueue:
for elt in to_push:
inserted = q.push(elt)
assert inserted
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
inserted = q.push(1)
assert not inserted
@@ -122,9 +158,7 @@ class TestMappedQueue:
h_sorted = sorted(h)
q = self._make_mapped_queue(h)
q._heapify()
- popped = []
- for elt in sorted(h):
- popped.append(q.pop())
+ popped = [q.pop() for _ in range(len(h))]
assert popped == h_sorted
self._check_map(q)
@@ -133,25 +167,66 @@ class TestMappedQueue:
h_removed = [0, 2, 1, 6, 4, 5]
q = self._make_mapped_queue(h)
removed = q.remove(3)
- assert q.h == h_removed
+ assert q.heap == h_removed
def test_remove_root(self):
h = [0, 2, 1, 6, 3, 5, 4]
h_removed = [1, 2, 4, 6, 3, 5]
q = self._make_mapped_queue(h)
removed = q.remove(0)
- assert q.h == h_removed
+ assert q.heap == h_removed
def test_update_leaf(self):
h = [0, 20, 10, 60, 30, 50, 40]
h_updated = [0, 15, 10, 60, 20, 50, 40]
q = self._make_mapped_queue(h)
removed = q.update(30, 15)
- assert q.h == h_updated
+ assert q.heap == h_updated
def test_update_root(self):
h = [0, 20, 10, 60, 30, 50, 40]
h_updated = [10, 20, 35, 60, 30, 50, 40]
q = self._make_mapped_queue(h)
removed = q.update(0, 35)
- assert q.h == h_updated
+ assert q.heap == h_updated
+
+
+class TestMappedDict(TestMappedQueue):
+ def _make_mapped_queue(self, h):
+ priority_dict = {elt: elt for elt in h}
+ return MappedQueue(priority_dict)
+
+ def test_push(self):
+ to_push = [6, 1, 4, 3, 2, 5, 0]
+ h_sifted = [0, 2, 1, 6, 3, 5, 4]
+ q = MappedQueue()
+ for elt in to_push:
+ q.push(elt, priority=elt)
+ assert q.heap == h_sifted
+ self._check_map(q)
+
+ def test_push_duplicate(self):
+ to_push = [2, 1, 0]
+ h_sifted = [0, 2, 1]
+ q = MappedQueue()
+ for elt in to_push:
+ inserted = q.push(elt, priority=elt)
+ assert inserted
+ assert q.heap == h_sifted
+ self._check_map(q)
+ inserted = q.push(1, priority=1)
+ assert not inserted
+
+ def test_update_leaf(self):
+ h = [0, 20, 10, 60, 30, 50, 40]
+ h_updated = [0, 15, 10, 60, 20, 50, 40]
+ q = self._make_mapped_queue(h)
+ removed = q.update(30, 15, priority=15)
+ assert q.heap == h_updated
+
+ def test_update_root(self):
+ h = [0, 20, 10, 60, 30, 50, 40]
+ h_updated = [10, 20, 35, 60, 30, 50, 40]
+ q = self._make_mapped_queue(h)
+ removed = q.update(0, 35, priority=35)
+ assert q.heap == h_updated
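To close, a sketch of the _HeapElement contract that test_HeapElement_eq and test_HeapElement_gtlt above pin down, since it is easy to trip over: equality and hashing follow the element, while ordering follows the priority. That split is what lets the position dict keep tracking an element across priority updates without floating-point round-off breaking dict lookups.

    from networkx.utils.mapped_queue import _HeapElement

    old = _HeapElement(1.0, ("u", "v"))
    new = _HeapElement(0.5, ("u", "v"))
    assert old == new and hash(old) == hash(new)  # same element: interchangeable dict keys
    assert new < old                              # ordering compares priorities
    assert old == ("u", "v")                      # bare elements compare equal as well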