summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Dan Schult <dschult@colgate.edu> 2021-09-07 11:31:40 -0400
committer: GitHub <noreply@github.com> 2021-09-07 11:31:40 -0400
commit: c9060dc933b7280597c3ba6d382784ed0ed45697 (patch)
tree: 2f276bff8950706dee44254bdcc29c11ad1b7955
parent: 59dc28a1297ec330d9bdaa0e018537615bb6750a (diff)
download: networkx-c9060dc933b7280597c3ba6d382784ed0ed45697.tar.gz
Allow greedy_modularity_communities to use floating point weights or resolution (#5065)
* revise mapped_queue to separate priority from element
* update max_modularity to use new mapped_queue
* change attribute names h, d to heap, position in MappedQueue
* clean up initialization of data structures and handling of q0
* change i,j,k notation to u,v,w (no indexes since gh-5007)
* Update networkx/utils/mapped_queue.py

Co-authored-by: Ross Barnowski <rossbar@berkeley.edu>
Co-authored-by: Ross Barnowski <rossbar@berkeley.edu>
-rw-r--r-- networkx/algorithms/community/modularity_max.py | 310
-rw-r--r-- networkx/algorithms/community/tests/test_modularity_max.py | 17
-rw-r--r-- networkx/utils/mapped_queue.py | 284
-rw-r--r-- networkx/utils/tests/test_mapped_queue.py | 127
4 files changed, 415 insertions, 323 deletions
diff --git a/networkx/algorithms/community/modularity_max.py b/networkx/algorithms/community/modularity_max.py
index 10c5d01c..10d42e7f 100644
--- a/networkx/algorithms/community/modularity_max.py
+++ b/networkx/algorithms/community/modularity_max.py
@@ -1,6 +1,6 @@
"""Functions for detecting communities based on modularity."""
-from collections import Counter
+from collections import defaultdict
import networkx as nx
from networkx.algorithms.community.quality import modularity
@@ -14,110 +14,6 @@ __all__ = [
]
-def _greedy_modularity_communities_init(G, weight=None, resolution=1):
- r"""Initializes the data structures for greedy_modularity_communities().
-
- Clauset-Newman-Moore Eq 8-9. Eq 8 was missing a factor of 2 (from A_ij + A_ji).
- See [2]_ at :func:`greedy_modularity_communities`.
-
- Parameters
- ----------
- G : NetworkX graph
-
- weight : string or None, optional (default=None)
- The name of an edge attribute that holds the numerical value used
- as a weight. If None, then each edge has weight 1.
- The degree is the sum of the edge weights adjacent to the node.
-
- resolution : float (default=1)
- If resolution is less than 1, modularity favors larger communities.
- Greater than 1 favors smaller communities.
-
- Returns
- -------
- dq_dict : dict of dict's
- dq_dict[i][j]: dQ for merging community i, j
-
- dq_heap : dict of MappedQueue's
- dq_heap[i][n] : (-dq, i, j) for communitiy i nth largest dQ
-
- H : MappedQueue
- (-dq, i, j) for community with nth largest max_j(dQ_ij)
-
- a, b : dict
- undirected:
- a[i]: fraction of (total weight of) edges within community i
- b : None
- directed:
- a[i]: fraction of (total weight of) edges with tails within community i
- b[i]: fraction of (total weight of) edges with heads within community i
-
- See Also
- --------
- :func:`greedy_modularity_communities`
- :func:`~networkx.algorithms.community.quality.modularity`
- """
- # Count nodes and edges (or the sum of edge-weights for weighted graphs)
- N = G.number_of_nodes()
- m = G.size(weight)
-
- # Calculate degrees
- if G.is_directed():
- k_in = dict(G.in_degree(weight=weight))
- k_out = dict(G.out_degree(weight=weight))
- q0 = 1.0 / m
- else:
- k_in = k_out = dict(G.degree(weight=weight))
- q0 = 1.0 / (2.0 * m)
-
- a = {node: kout * q0 for node, kout in k_out.items()}
- if G.is_directed():
- b = {node: kin * q0 for node, kin in k_in.items()}
- else:
- b = None
-
- dq_dict = {
- i: {
- j: q0
- * (
- G.get_edge_data(i, j, default={weight: 0}).get(weight, 1.0)
- + G.get_edge_data(j, i, default={weight: 0}).get(weight, 1.0)
- - resolution * q0 * (k_out[i] * k_in[j] + k_in[i] * k_out[j])
- )
- for j in nx.all_neighbors(G, i)
- if j != i
- }
- for i in G.nodes()
- }
-
- # dq correction for multi-edges
- # In case of multi-edges, get_edge_data(i, j) returns the key: data dict of the i, j
- # edges, which does not have a 'weight' key. Therefore, when calculating dq for i, j
- # Aij is always 1.0 and a correction is required.
- if G.is_multigraph():
- edges_count = dict(Counter(G.edges()))
- multi_edges = [edge for edge, count in edges_count.items() if count > 1]
- for edge in multi_edges:
- total_wt = sum(d.get(weight, 1) for d in G.get_edge_data(*edge).values())
- if G.is_directed():
- # The correction applies only to the direction of the edge. The edge at
- # the other direction is either not a multiedge (where the weight is
- # added correctly), non-existent or it is also a multiedge, in which
- # case it will be handled singly when its turn in the loop comes.
- q00 = q0
- else:
- q00 = 2 * q0
- dq_dict[edge[0]][edge[1]] += q00 * (total_wt - 1)
- dq_dict[edge[1]][edge[0]] += q00 * (total_wt - 1)
-
- dq_heap = {
- i: MappedQueue([(-dq, i, j) for j, dq in dq_dict[i].items()]) for i in G.nodes()
- }
- H = MappedQueue([dq_heap[i].h[0] for i in G.nodes() if len(dq_heap[i]) > 0])
-
- return dq_dict, dq_heap, H, a, b
-
-
def greedy_modularity_communities(G, weight=None, resolution=1, n_communities=1):
r"""Find communities in G using greedy modularity maximization.
@@ -182,20 +78,48 @@ def greedy_modularity_communities(G, weight=None, resolution=1, n_communities=1)
.. [4] Newman, M. E. J."Analysis of weighted networks"
Physical Review E 70(5 Pt 2):056131, 2004.
"""
+ directed = G.is_directed()
N = G.number_of_nodes()
if (n_communities < 1) or (n_communities > N):
raise ValueError(
f"n_communities must be between 1 and {N}. Got {n_communities}"
)
- # Initialize data structures
- dq_dict, dq_heap, H, a, b = _greedy_modularity_communities_init(
- G, weight, resolution
- )
+ # Count edges (or the sum of edge-weights for weighted graphs)
+ m = G.size(weight)
+ q0 = 1 / m
+
+ # Calculate degrees (notation from the papers)
+ # a : the fraction of (weighted) out-degree for each node
+ # b : the fraction of (weighted) in-degree for each node
+ if directed:
+ a = {node: deg_out * q0 for node, deg_out in G.out_degree(weight=weight)}
+ b = {node: deg_in * q0 for node, deg_in in G.in_degree(weight=weight)}
+ else:
+ a = b = {node: deg * q0 * 0.5 for node, deg in G.degree(weight=weight)}
+
+ # this preliminary step collects the edge weights for each node pair
+ # It handles multigraph and digraph and works fine for graph.
+ dq_dict = defaultdict(lambda: defaultdict(float))
+ for u, v, wt in G.edges(data=weight, default=1):
+ if u == v:
+ continue
+ dq_dict[u][v] += wt
+ dq_dict[v][u] += wt
+
+ # now scale and subtract the expected edge-weights term
+ for u, nbrdict in dq_dict.items():
+ for v, wt in nbrdict.items():
+ dq_dict[u][v] = q0 * wt - resolution * (a[u] * b[v] + b[u] * a[v])
+
+ # Use -dq to get a max_heap instead of a min_heap
+ # dq_heap holds a heap for each node's neighbors
+ dq_heap = {u: MappedQueue({(u, v): -dq for v, dq in dq_dict[u].items()}) for u in G}
+ # H -> all_dq_heap holds a heap with the best items for each node
+ H = MappedQueue([dq_heap[n].heap[0] for n in G if len(dq_heap[n]) > 0])
+
# Initialize single-node communities
- communities = {i: frozenset([i]) for i in G.nodes()}
- # Initial modularity
- q_cnm = modularity(G, communities.values(), resolution=resolution)
+ communities = {n: frozenset([n]) for n in G}
# Merge communities until we can't improve modularity or until desired number of
# communities (n_communities) is reached.
@@ -204,123 +128,113 @@ def greedy_modularity_communities(G, weight=None, resolution=1, n_communities=1)
# Remove from heap of row maxes
# Ties will be broken by choosing the pair with lowest min community id
try:
- dq, i, j = H.pop()
+ negdq, u, v = H.pop()
except IndexError:
break
- dq = -dq
- # Remove best merge from row i heap
- dq_heap[i].pop()
+ dq = -negdq
+ # Remove best merge from row u heap
+ dq_heap[u].pop()
# Push new row max onto H
- if len(dq_heap[i]) > 0:
- H.push(dq_heap[i].h[0])
- # If this element was also at the root of row j, we need to remove the
+ if len(dq_heap[u]) > 0:
+ H.push(dq_heap[u].heap[0])
+ # If this element was also at the root of row v, we need to remove the
# duplicate entry from H
- if dq_heap[j].h[0] == (-dq, j, i):
- H.remove((-dq, j, i))
- # Remove best merge from row j heap
- dq_heap[j].remove((-dq, j, i))
+ if dq_heap[v].heap[0] == (v, u):
+ H.remove((v, u))
+ # Remove best merge from row v heap
+ dq_heap[v].remove((v, u))
# Push new row max onto H
- if len(dq_heap[j]) > 0:
- H.push(dq_heap[j].h[0])
+ if len(dq_heap[v]) > 0:
+ H.push(dq_heap[v].heap[0])
else:
- # Duplicate wasn't in H, just remove from row j heap
- dq_heap[j].remove((-dq, j, i))
- # Stop when change is non-positive
+ # Duplicate wasn't in H, just remove from row v heap
+ dq_heap[v].remove((v, u))
+ # Stop when change is non-positive (no improvement possible)
if dq <= 0:
break
# Perform merge
- communities[j] = frozenset(communities[i] | communities[j])
- del communities[i]
- # New modularity
- q_cnm += dq
- # Get list of communities connected to merged communities
- i_set = set(dq_dict[i].keys())
- j_set = set(dq_dict[j].keys())
- all_set = (i_set | j_set) - {i, j}
- both_set = i_set & j_set
- # Merge i into j and update dQ
- for k in all_set:
+ communities[v] = frozenset(communities[u] | communities[v])
+ del communities[u]
+
+ # Get neighbor communities connected to the merged communities
+ u_nbrs = set(dq_dict[u])
+ v_nbrs = set(dq_dict[v])
+ all_nbrs = (u_nbrs | v_nbrs) - {u, v}
+ both_nbrs = u_nbrs & v_nbrs
+ # Update dq for merge of u into v
+ for w in all_nbrs:
# Calculate new dq value
- if k in both_set:
- dq_jk = dq_dict[j][k] + dq_dict[i][k]
- elif k in j_set:
- if G.is_directed():
- dq_jk = dq_dict[j][k] - resolution * (a[i] * b[k] + a[k] * b[i])
- else:
- dq_jk = dq_dict[j][k] - 2.0 * resolution * a[i] * a[k]
- else:
- # k in i_set
- if G.is_directed():
- dq_jk = dq_dict[i][k] - resolution * (a[j] * b[k] + a[k] * b[j])
- else:
- dq_jk = dq_dict[i][k] - 2.0 * resolution * a[j] * a[k]
- # Update rows j and k
- for row, col in [(j, k), (k, j)]:
- # Save old value for finding heap index
- if k in j_set:
- d_old = (-dq_dict[row][col], row, col)
- else:
- d_old = None
- # Update dict for j,k only (i is removed below)
- dq_dict[row][col] = dq_jk
+ if w in both_nbrs:
+ dq_vw = dq_dict[v][w] + dq_dict[u][w]
+ elif w in v_nbrs:
+ dq_vw = dq_dict[v][w] - resolution * (a[u] * b[w] + a[w] * b[u])
+ else: # w in u_nbrs
+ dq_vw = dq_dict[u][w] - resolution * (a[v] * b[w] + a[w] * b[v])
+ # Update rows v and w
+ for row, col in [(v, w), (w, v)]:
+ dq_heap_row = dq_heap[row]
+ # Update dict for v,w only (u is removed below)
+ dq_dict[row][col] = dq_vw
# Save old max of per-row heap
- if len(dq_heap[row]) > 0:
- d_oldmax = dq_heap[row].h[0]
+ if len(dq_heap_row) > 0:
+ d_oldmax = dq_heap_row.heap[0]
else:
d_oldmax = None
# Add/update heaps
- d = (-dq_jk, row, col)
- if d_old is None:
- # We're creating a new nonzero element, add to heap
- dq_heap[row].push(d)
- else:
+ d = (row, col)
+ d_negdq = -dq_vw
+ # Save old value for finding heap index
+ if w in v_nbrs:
# Update existing element in per-row heap
- dq_heap[row].update(d_old, d)
+ dq_heap_row.update(d, d, priority=d_negdq)
+ else:
+ # We're creating a new nonzero element, add to heap
+ dq_heap_row.push(d, priority=d_negdq)
# Update heap of row maxes if necessary
if d_oldmax is None:
# No entries previously in this row, push new max
- H.push(d)
+ H.push(d, priority=d_negdq)
else:
# We've updated an entry in this row, has the max changed?
- if dq_heap[row].h[0] != d_oldmax:
- H.update(d_oldmax, dq_heap[row].h[0])
+ row_max = dq_heap_row.heap[0]
+ if d_oldmax != row_max or d_oldmax.priority != row_max.priority:
+ H.update(d_oldmax, row_max)
- # Remove row/col i from matrix
- i_neighbors = dq_dict[i].keys()
- for k in i_neighbors:
+ # Remove row/col u from dq_dict matrix
+ for w in dq_dict[u]:
# Remove from dict
- dq_old = dq_dict[k][i]
- del dq_dict[k][i]
+ dq_old = dq_dict[w][u]
+ del dq_dict[w][u]
# Remove from heaps if we haven't already
- if k != j:
+ if w != v:
# Remove both row and column
- for row, col in [(k, i), (i, k)]:
+ for row, col in [(w, u), (u, w)]:
+ dq_heap_row = dq_heap[row]
# Check if replaced dq is row max
- d_old = (-dq_old, row, col)
- if dq_heap[row].h[0] == d_old:
+ d_old = (row, col)
+ if dq_heap_row.heap[0] == d_old:
# Update per-row heap and heap of row maxes
- dq_heap[row].remove(d_old)
+ dq_heap_row.remove(d_old)
H.remove(d_old)
# Update row max
- if len(dq_heap[row]) > 0:
- H.push(dq_heap[row].h[0])
+ if len(dq_heap_row) > 0:
+ H.push(dq_heap_row.heap[0])
else:
# Only update per-row heap
- dq_heap[row].remove(d_old)
-
- del dq_dict[i]
- # Mark row i as deleted, but keep placeholder
- dq_heap[i] = MappedQueue()
- # Merge i into j and update a
- a[j] += a[i]
- a[i] = 0
- if G.is_directed():
- b[j] += b[i]
- b[i] = 0
-
- partition = sorted(communities.values(), key=len, reverse=True)
- return partition
+ dq_heap_row.remove(d_old)
+
+ del dq_dict[u]
+ # Mark row u as deleted, but keep placeholder
+ dq_heap[u] = MappedQueue()
+ # Merge u into v and update a
+ a[v] += a[u]
+ a[u] = 0
+ if directed:
+ b[v] += b[u]
+ b[u] = 0
+
+ return sorted(communities.values(), key=len, reverse=True)
@not_implemented_for("directed")
diff --git a/networkx/algorithms/community/tests/test_modularity_max.py b/networkx/algorithms/community/tests/test_modularity_max.py
index 3bae8528..433ca746 100644
--- a/networkx/algorithms/community/tests/test_modularity_max.py
+++ b/networkx/algorithms/community/tests/test_modularity_max.py
@@ -91,6 +91,23 @@ def test_modularity_communities_weighted():
expected = [{0, 1, 3, 4, 7, 8, 9, 10}, {2, 5, 6, 11, 12, 13, 14}]
assert greedy_modularity_communities(G, weight="weight") == expected
+ assert greedy_modularity_communities(G, weight="weight", resolution=0.9) == expected
+ assert greedy_modularity_communities(G, weight="weight", resolution=0.3) == expected
+ assert greedy_modularity_communities(G, weight="weight", resolution=1.1) != expected
+
+
+def test_modularity_communities_floating_point():
+ # check for floating point error when used as key in the mapped_queue dict.
+ # Test for gh-4992 and gh-5000
+ G = nx.Graph()
+ G.add_weighted_edges_from(
+ [(0, 1, 12), (1, 4, 71), (2, 3, 15), (2, 4, 10), (3, 6, 13)]
+ )
+ expected = [{0, 1, 4}, {2, 3, 6}]
+ assert greedy_modularity_communities(G, weight="weight") == expected
+ assert (
+ greedy_modularity_communities(G, weight="weight", resolution=0.99) == expected
+ )
def test_modularity_communities_directed_weighted():
diff --git a/networkx/utils/mapped_queue.py b/networkx/utils/mapped_queue.py
index 5888348e..0ff53a0b 100644
--- a/networkx/utils/mapped_queue.py
+++ b/networkx/utils/mapped_queue.py
@@ -6,16 +6,92 @@ import heapq
__all__ = ["MappedQueue"]
+class _HeapElement:
+ """This proxy class separates the heap element from its priority.
+
+ The idea is that using a 2-tuple (priority, element) works
+ for sorting, but not for dict lookup because priorities are
+ often floating point values so round-off can mess up equality.
+
+ So, we need inequalities to look at the priority (for sorting)
+ and equality (and hash) to look at the element to enable
+ updates to the priority.
+
+ Unfortunately, this class can be tricky to work with if you forget that
+ `__lt__` compares the priority while `__eq__` compares the element.
+ In `greedy_modularity_communities()` the following code is
+ used to check that two _HeapElements differ in either element or priority:
+
+ if d_oldmax != row_max or d_oldmax.priority != row_max.priority:
+
+ If the priorities are the same, this implementation uses the element
+ as a tiebreaker. This provides compatibility with older systems that
+ use tuples to combine priority and elements.
+ """
+
+ __slots__ = ["priority", "element", "_hash"]
+
+ def __init__(self, priority, element):
+ self.priority = priority
+ self.element = element
+ self._hash = hash(element)
+
+ def __lt__(self, other):
+ try:
+ other_priority = other.priority
+ except AttributeError:
+ return self.priority < other
+ # assume comparing to another _HeapElement
+ if self.priority == other_priority:
+ return self.element < other.element
+ return self.priority < other_priority
+
+ def __gt__(self, other):
+ try:
+ other_priority = other.priority
+ except AttributeError:
+ return self.priority > other
+ # assume comparing to another _HeapElement
+ if self.priority == other_priority:
+ return self.element < other.element
+ return self.priority > other_priority
+
+ def __eq__(self, other):
+ try:
+ return self.element == other.element
+ except AttributeError:
+ return self.element == other
+
+ def __hash__(self):
+ return self._hash
+
+ def __getitem__(self, indx):
+ return self.priority if indx == 0 else self.element[indx - 1]
+
+ def __iter__(self):
+ yield self.priority
+ try:
+ yield from self.element
+ except TypeError:
+ yield self.element
+
+ def __repr__(self):
+ return f"_HeapElement({self.priority}, {self.element})"
+
+
class MappedQueue:
- """The MappedQueue class implements an efficient minimum heap. The
- smallest element can be popped in O(1) time, new elements can be pushed
- in O(log n) time, and any element can be removed or updated in O(log n)
- time. The queue cannot contain duplicate elements and an attempt to push an
- element already in the queue will have no effect.
+ """The MappedQueue class implements a min-heap with removal and update-priority.
+
+ The min heap uses heapq as well as custom written _siftup and _siftdown
+ methods to allow the heap positions to be tracked by an additional dict
+ keyed by element to position. The smallest element can be popped in O(1) time,
+ new elements can be pushed in O(log n) time, and any element can be removed
+ or updated in O(log n) time. The queue cannot contain duplicate elements
+ and an attempt to push an element already in the queue will have no effect.
MappedQueue complements the heapq package from the python standard
library. While MappedQueue is designed for maximum compatibility with
- heapq, it has slightly different functionality.
+ heapq, it adds element removal, lookup, and priority update.
Examples
--------
@@ -27,8 +103,7 @@ class MappedQueue:
>>> q = MappedQueue([916, 50, 4609, 493, 237])
>>> q.push(1310)
True
- >>> x = [q.pop() for i in range(len(q.h))]
- >>> x
+ >>> [q.pop() for i in range(len(q.heap))]
[50, 237, 493, 916, 1310, 4609]
Elements can also be updated or removed from anywhere in the queue.
@@ -36,8 +111,7 @@ class MappedQueue:
>>> q = MappedQueue([916, 50, 4609, 493, 237])
>>> q.remove(493)
>>> q.update(237, 1117)
- >>> x = [q.pop() for i in range(len(q.h))]
- >>> x
+ >>> [q.pop() for i in range(len(q.heap))]
[50, 916, 1117, 4609]
References
@@ -50,132 +124,144 @@ class MappedQueue:
def __init__(self, data=[]):
"""Priority queue class with updatable priorities."""
- self.h = list(data)
- self.d = dict()
+ if isinstance(data, dict):
+ self.heap = [_HeapElement(v, k) for k, v in data.items()]
+ else:
+ self.heap = list(data)
+ self.position = dict()
self._heapify()
- def __len__(self):
- return len(self.h)
-
def _heapify(self):
"""Restore heap invariant and recalculate map."""
- heapq.heapify(self.h)
- self.d = {elt: pos for pos, elt in enumerate(self.h)}
- if len(self.h) != len(self.d):
+ heapq.heapify(self.heap)
+ self.position = {elt: pos for pos, elt in enumerate(self.heap)}
+ if len(self.heap) != len(self.position):
raise AssertionError("Heap contains duplicate elements")
- def push(self, elt):
+ def __len__(self):
+ return len(self.heap)
+
+ def push(self, elt, priority=None):
"""Add an element to the queue."""
+ if priority is not None:
+ elt = _HeapElement(priority, elt)
# If element is already in queue, do nothing
- if elt in self.d:
+ if elt in self.position:
return False
# Add element to heap and dict
- pos = len(self.h)
- self.h.append(elt)
- self.d[elt] = pos
+ pos = len(self.heap)
+ self.heap.append(elt)
+ self.position[elt] = pos
# Restore invariant by sifting down
- self._siftdown(pos)
+ self._siftdown(0, pos)
return True
def pop(self):
"""Remove and return the smallest element in the queue."""
# Remove smallest element
- elt = self.h[0]
- del self.d[elt]
+ elt = self.heap[0]
+ del self.position[elt]
# If elt is last item, remove and return
- if len(self.h) == 1:
- self.h.pop()
+ if len(self.heap) == 1:
+ self.heap.pop()
return elt
# Replace root with last element
- last = self.h.pop()
- self.h[0] = last
- self.d[last] = 0
- # Restore invariant by sifting up, then down
- pos = self._siftup(0)
- self._siftdown(pos)
+ last = self.heap.pop()
+ self.heap[0] = last
+ self.position[last] = 0
+ # Restore invariant by sifting up
+ self._siftup(0)
# Return smallest element
return elt
- def update(self, elt, new):
+ def update(self, elt, new, priority=None):
"""Replace an element in the queue with a new one."""
+ if priority is not None:
+ new = _HeapElement(priority, new)
# Replace
- pos = self.d[elt]
- self.h[pos] = new
- del self.d[elt]
- self.d[new] = pos
- # Restore invariant by sifting up, then down
- pos = self._siftup(pos)
- self._siftdown(pos)
+ pos = self.position[elt]
+ self.heap[pos] = new
+ del self.position[elt]
+ self.position[new] = pos
+ # Restore invariant by sifting up
+ self._siftup(pos)
def remove(self, elt):
"""Remove an element from the queue."""
# Find and remove element
try:
- pos = self.d[elt]
- del self.d[elt]
+ pos = self.position[elt]
+ del self.position[elt]
except KeyError:
# Not in queue
raise
# If elt is last item, remove and return
- if pos == len(self.h) - 1:
- self.h.pop()
+ if pos == len(self.heap) - 1:
+ self.heap.pop()
return
# Replace elt with last element
- last = self.h.pop()
- self.h[pos] = last
- self.d[last] = pos
- # Restore invariant by sifting up, then down
- pos = self._siftup(pos)
- self._siftdown(pos)
+ last = self.heap.pop()
+ self.heap[pos] = last
+ self.position[last] = pos
+ # Restore invariant by sifting up
+ self._siftup(pos)
def _siftup(self, pos):
- """Move element at pos down to a leaf by repeatedly moving the smaller
- child up."""
- h, d = self.h, self.d
- elt = h[pos]
- # Continue until element is in a leaf
- end_pos = len(h)
- left_pos = (pos << 1) + 1
- while left_pos < end_pos:
- # Left child is guaranteed to exist by loop predicate
- left = h[left_pos]
- try:
- right_pos = left_pos + 1
- right = h[right_pos]
- # Out-of-place, swap with left unless right is smaller
- if right < left:
- h[pos], h[right_pos] = right, elt
- pos, right_pos = right_pos, pos
- d[elt], d[right] = pos, right_pos
- else:
- h[pos], h[left_pos] = left, elt
- pos, left_pos = left_pos, pos
- d[elt], d[left] = pos, left_pos
- except IndexError:
- # Left leaf is the end of the heap, swap
- h[pos], h[left_pos] = left, elt
- pos, left_pos = left_pos, pos
- d[elt], d[left] = pos, left_pos
- # Update left_pos
- left_pos = (pos << 1) + 1
- return pos
-
- def _siftdown(self, pos):
- """Restore invariant by repeatedly replacing out-of-place element with
- its parent."""
- h, d = self.h, self.d
- elt = h[pos]
- # Continue until element is at root
+ """Move smaller child up until hitting a leaf.
+
+ Built to mimic code for heapq._siftup
+ only updating position dict too.
+ """
+ heap, position = self.heap, self.position
+ end_pos = len(heap)
+ startpos = pos
+ newitem = heap[pos]
+ # Shift up the smaller child until hitting a leaf
+ child_pos = (pos << 1) + 1 # start with leftmost child position
+ while child_pos < end_pos:
+ # Set child_pos to index of smaller child.
+ child = heap[child_pos]
+ right_pos = child_pos + 1
+ if right_pos < end_pos:
+ right = heap[right_pos]
+ if not child < right:
+ child = right
+ child_pos = right_pos
+ # Move the smaller child up.
+ heap[pos] = child
+ position[child] = pos
+ pos = child_pos
+ child_pos = (pos << 1) + 1
+ # pos is a leaf position. Put newitem there, and bubble it up
+ # to its final resting place (by sifting its parents down).
while pos > 0:
parent_pos = (pos - 1) >> 1
- parent = h[parent_pos]
- if parent > elt:
- # Swap out-of-place element with parent
- h[parent_pos], h[pos] = elt, parent
- parent_pos, pos = pos, parent_pos
- d[elt] = pos
- d[parent] = parent_pos
- else:
- # Invariant is satisfied
+ parent = heap[parent_pos]
+ if not newitem < parent:
+ break
+ heap[pos] = parent
+ position[parent] = pos
+ pos = parent_pos
+ heap[pos] = newitem
+ position[newitem] = pos
+
+ def _siftdown(self, start_pos, pos):
+ """Restore invariant. keep swapping with parent until smaller.
+
+ Built to mimic code for heapq._siftdown
+ only updating position dict too.
+ """
+ heap, position = self.heap, self.position
+ newitem = heap[pos]
+ # Follow the path to the root, moving parents down until finding a place
+ # newitem fits.
+ while pos > start_pos:
+ parent_pos = (pos - 1) >> 1
+ parent = heap[parent_pos]
+ if not newitem < parent:
break
- return pos
+ heap[pos] = parent
+ position[parent] = pos
+ pos = parent_pos
+ heap[pos] = newitem
+ position[newitem] = pos
diff --git a/networkx/utils/tests/test_mapped_queue.py b/networkx/utils/tests/test_mapped_queue.py
index 78ea91ec..89e251d4 100644
--- a/networkx/utils/tests/test_mapped_queue.py
+++ b/networkx/utils/tests/test_mapped_queue.py
@@ -1,4 +1,41 @@
-from networkx.utils.mapped_queue import MappedQueue
+import pytest
+from networkx.utils.mapped_queue import _HeapElement, MappedQueue
+
+
+def test_HeapElement_gtlt():
+ bar = _HeapElement(1.1, "a")
+ foo = _HeapElement(1, "b")
+ assert foo < bar
+ assert bar > foo
+ assert foo < 1.1
+ assert 1 < bar
+
+
+def test_HeapElement_eq():
+ bar = _HeapElement(1.1, "a")
+ foo = _HeapElement(1, "a")
+ assert foo == bar
+ assert bar == foo
+ assert foo == "a"
+
+
+def test_HeapElement_iter():
+ foo = _HeapElement(1, "a")
+ bar = _HeapElement(1.1, (3, 2, 1))
+ assert list(foo) == [1, "a"]
+ assert list(bar) == [1.1, 3, 2, 1]
+
+
+def test_HeapElement_getitem():
+ foo = _HeapElement(1, "a")
+ bar = _HeapElement(1.1, (3, 2, 1))
+ assert foo[1] == "a"
+ assert foo[0] == 1
+ assert bar[0] == 1.1
+ assert bar[2] == 2
+ assert bar[3] == 1
+ pytest.raises(IndexError, bar.__getitem__, 4)
+ pytest.raises(IndexError, foo.__getitem__, 2)
class TestMappedQueue:
@@ -6,13 +43,12 @@ class TestMappedQueue:
pass
def _check_map(self, q):
- d = {elt: pos for pos, elt in enumerate(q.h)}
- assert d == q.d
+ assert q.position == {elt: pos for pos, elt in enumerate(q.heap)}
def _make_mapped_queue(self, h):
q = MappedQueue()
- q.h = h
- q.d = {elt: pos for pos, elt in enumerate(h)}
+ q.heap = h
+ q.position = {elt: pos for pos, elt in enumerate(h)}
return q
def test_heapify(self):
@@ -37,7 +73,7 @@ class TestMappedQueue:
h_sifted = [2]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftup_one_child(self):
@@ -45,7 +81,7 @@ class TestMappedQueue:
h_sifted = [0, 2]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftup_left_child(self):
@@ -53,7 +89,7 @@ class TestMappedQueue:
h_sifted = [0, 2, 1]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftup_right_child(self):
@@ -61,39 +97,39 @@ class TestMappedQueue:
h_sifted = [0, 1, 2]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftup_multiple(self):
h = [0, 1, 2, 4, 3, 5, 6]
- h_sifted = [1, 3, 2, 4, 0, 5, 6]
+ h_sifted = [0, 1, 2, 4, 3, 5, 6]
q = self._make_mapped_queue(h)
q._siftup(0)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftdown_leaf(self):
h = [2]
h_sifted = [2]
q = self._make_mapped_queue(h)
- q._siftdown(0)
- assert q.h == h_sifted
+ q._siftdown(0, 0)
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftdown_single(self):
h = [1, 0]
h_sifted = [0, 1]
q = self._make_mapped_queue(h)
- q._siftdown(len(h) - 1)
- assert q.h == h_sifted
+ q._siftdown(0, len(h) - 1)
+ assert q.heap == h_sifted
self._check_map(q)
def test_siftdown_multiple(self):
h = [1, 2, 3, 4, 5, 6, 7, 0]
h_sifted = [0, 1, 3, 2, 5, 6, 7, 4]
q = self._make_mapped_queue(h)
- q._siftdown(len(h) - 1)
- assert q.h == h_sifted
+ q._siftdown(0, len(h) - 1)
+ assert q.heap == h_sifted
self._check_map(q)
def test_push(self):
@@ -102,7 +138,7 @@ class TestMappedQueue:
q = MappedQueue()
for elt in to_push:
q.push(elt)
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
def test_push_duplicate(self):
@@ -112,7 +148,7 @@ class TestMappedQueue:
for elt in to_push:
inserted = q.push(elt)
assert inserted
- assert q.h == h_sifted
+ assert q.heap == h_sifted
self._check_map(q)
inserted = q.push(1)
assert not inserted
@@ -122,9 +158,7 @@ class TestMappedQueue:
h_sorted = sorted(h)
q = self._make_mapped_queue(h)
q._heapify()
- popped = []
- for elt in sorted(h):
- popped.append(q.pop())
+ popped = [q.pop() for _ in range(len(h))]
assert popped == h_sorted
self._check_map(q)
@@ -133,25 +167,66 @@ class TestMappedQueue:
h_removed = [0, 2, 1, 6, 4, 5]
q = self._make_mapped_queue(h)
removed = q.remove(3)
- assert q.h == h_removed
+ assert q.heap == h_removed
def test_remove_root(self):
h = [0, 2, 1, 6, 3, 5, 4]
h_removed = [1, 2, 4, 6, 3, 5]
q = self._make_mapped_queue(h)
removed = q.remove(0)
- assert q.h == h_removed
+ assert q.heap == h_removed
def test_update_leaf(self):
h = [0, 20, 10, 60, 30, 50, 40]
h_updated = [0, 15, 10, 60, 20, 50, 40]
q = self._make_mapped_queue(h)
removed = q.update(30, 15)
- assert q.h == h_updated
+ assert q.heap == h_updated
def test_update_root(self):
h = [0, 20, 10, 60, 30, 50, 40]
h_updated = [10, 20, 35, 60, 30, 50, 40]
q = self._make_mapped_queue(h)
removed = q.update(0, 35)
- assert q.h == h_updated
+ assert q.heap == h_updated
+
+
+class TestMappedDict(TestMappedQueue):
+ def _make_mapped_queue(self, h):
+ priority_dict = {elt: elt for elt in h}
+ return MappedQueue(priority_dict)
+
+ def test_push(self):
+ to_push = [6, 1, 4, 3, 2, 5, 0]
+ h_sifted = [0, 2, 1, 6, 3, 5, 4]
+ q = MappedQueue()
+ for elt in to_push:
+ q.push(elt, priority=elt)
+ assert q.heap == h_sifted
+ self._check_map(q)
+
+ def test_push_duplicate(self):
+ to_push = [2, 1, 0]
+ h_sifted = [0, 2, 1]
+ q = MappedQueue()
+ for elt in to_push:
+ inserted = q.push(elt, priority=elt)
+ assert inserted
+ assert q.heap == h_sifted
+ self._check_map(q)
+ inserted = q.push(1, priority=1)
+ assert not inserted
+
+ def test_update_leaf(self):
+ h = [0, 20, 10, 60, 30, 50, 40]
+ h_updated = [0, 15, 10, 60, 20, 50, 40]
+ q = self._make_mapped_queue(h)
+ removed = q.update(30, 15, priority=15)
+ assert q.heap == h_updated
+
+ def test_update_root(self):
+ h = [0, 20, 10, 60, 30, 50, 40]
+ h_updated = [10, 20, 35, 60, 30, 50, 40]
+ q = self._make_mapped_queue(h)
+ removed = q.update(0, 35, priority=35)
+ assert q.heap == h_updated