diff options
author | Ross Barnowski <rossbar@berkeley.edu> | 2021-10-25 09:42:06 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-25 09:42:06 -0700 |
commit | 7deded8bfbebf2401005c6fe2d4fa7f67cf9f74f (patch) | |
tree | 88a48d17287a84a6d36251824e67f8c7f27c0039 | |
parent | a06a720abce00f6fbff2a165125d94fdfdc1f9f5 (diff) | |
download | networkx-7deded8bfbebf2401005c6fe2d4fa7f67cf9f74f.tar.gz |
Refactor node_classification to improve conciseness and readability (#5144)
* Rm _init_label_matrix indirection.
* Factor out _predict indirection.
* Factor out _propagate one-liner.
* Factor out _build_base_matrix fn.
* Minor cleanups.
* Fix up docstrings.
* Rm internal _build_propagation_matrix fn.
* Rm unnecessary binding of local variable.
-rw-r--r-- | networkx/algorithms/node_classification/hmn.py | 90 | ||||
-rw-r--r-- | networkx/algorithms/node_classification/lgc.py | 87 | ||||
-rw-r--r-- | networkx/algorithms/node_classification/utils.py | 65 |
3 files changed, 38 insertions, 204 deletions
diff --git a/networkx/algorithms/node_classification/hmn.py b/networkx/algorithms/node_classification/hmn.py index 9cdb56ff..6b7fab77 100644 --- a/networkx/algorithms/node_classification/hmn.py +++ b/networkx/algorithms/node_classification/hmn.py @@ -9,12 +9,7 @@ In ICML (Vol. 3, pp. 912-919). import networkx as nx from networkx.utils.decorators import not_implemented_for -from networkx.algorithms.node_classification.utils import ( - _get_label_info, - _init_label_matrix, - _propagate, - _predict, -) +from networkx.algorithms.node_classification.utils import _get_label_info __all__ = ["harmonic_function"] @@ -33,13 +28,13 @@ def harmonic_function(G, max_iter=30, label_name="label"): Returns ------- - predicted : array, shape = [n_samples] - Array of predicted labels + predicted : list + List of length ``len(G)`` with the predicted labels for each node. Raises ------ NetworkXError - If no nodes on `G` has `label_name`. + If no nodes in `G` have attribute `label_name`. Examples -------- @@ -65,72 +60,29 @@ def harmonic_function(G, max_iter=30, label_name="label"): import scipy as sp import scipy.sparse # call as sp.sparse - def _build_propagation_matrix(X, labels): - """Build propagation matrix of Harmonic function - - Parameters - ---------- - X : scipy sparse matrix, shape = [n_samples, n_samples] - Adjacency matrix - labels : array, shape = [n_samples, 2] - Array of pairs of node id and label id - - Returns - ------- - P : scipy sparse matrix, shape = [n_samples, n_samples] - Propagation matrix - - """ - degrees = X.sum(axis=0).A[0] - degrees[degrees == 0] = 1 # Avoid division by 0 - D = sp.sparse.diags((1.0 / degrees), offsets=0) - P = (D @ X).tolil() - P[labels[:, 0]] = 0 # labels[:, 0] indicates IDs of labeled nodes - return P - - def _build_base_matrix(X, labels, n_classes): - """Build base matrix of Harmonic function - - Parameters - ---------- - X : scipy sparse matrix, shape = [n_samples, n_samples] - Adjacency matrix - labels : array, shape = [n_samples, 
2] - Array of pairs of node id and label id - n_classes : integer - The number of classes (distinct labels) on the input graph - - Returns - ------- - B : array, shape = [n_samples, n_classes] - Base matrix - """ - n_samples = X.shape[0] - B = np.zeros((n_samples, n_classes)) - B[labels[:, 0], labels[:, 1]] = 1 - return B - X = nx.to_scipy_sparse_matrix(G) # adjacency matrix labels, label_dict = _get_label_info(G, label_name) if labels.shape[0] == 0: raise nx.NetworkXError( - "No node on the input graph is labeled by '" + label_name + "'." + f"No node on the input graph is labeled by '{label_name}'." ) n_samples = X.shape[0] n_classes = label_dict.shape[0] - - F = _init_label_matrix(n_samples, n_classes) - - P = _build_propagation_matrix(X, labels) - B = _build_base_matrix(X, labels, n_classes) - - remaining_iter = max_iter - while remaining_iter > 0: - F = _propagate(P, F, B) - remaining_iter -= 1 - - predicted = _predict(F, label_dict) - - return predicted + F = np.zeros((n_samples, n_classes)) + + # Build propagation matrix + degrees = X.sum(axis=0).A[0] + degrees[degrees == 0] = 1 # Avoid division by 0 + D = sp.sparse.diags((1.0 / degrees), offsets=0) + P = (D @ X).tolil() + P[labels[:, 0]] = 0 # labels[:, 0] indicates IDs of labeled nodes + # Build base matrix + B = np.zeros((n_samples, n_classes)) + B[labels[:, 0], labels[:, 1]] = 1 + + for _ in range(max_iter): + F = (P @ F) + B + + return label_dict[np.argmax(F, axis=1)].tolist() diff --git a/networkx/algorithms/node_classification/lgc.py b/networkx/algorithms/node_classification/lgc.py index f873c2b2..ca4daa03 100644 --- a/networkx/algorithms/node_classification/lgc.py +++ b/networkx/algorithms/node_classification/lgc.py @@ -9,12 +9,7 @@ Advances in neural information processing systems, 16(16), 321-328. 
import networkx as nx from networkx.utils.decorators import not_implemented_for -from networkx.algorithms.node_classification.utils import ( - _get_label_info, - _init_label_matrix, - _propagate, - _predict, -) +from networkx.algorithms.node_classification.utils import _get_label_info __all__ = ["local_and_global_consistency"] @@ -35,13 +30,13 @@ def local_and_global_consistency(G, alpha=0.99, max_iter=30, label_name="label") Returns ------- - predicted : array, shape = [n_samples] - Array of predicted labels + predicted : list + List of length ``len(G)`` with the predicted labels for each node. Raises ------ NetworkXError - If no nodes on `G` has `label_name`. + If no nodes in `G` have attribute `label_name`. Examples -------- @@ -57,7 +52,6 @@ def local_and_global_consistency(G, alpha=0.99, max_iter=30, label_name="label") >>> predicted ['A', 'A', 'B', 'B'] - References ---------- Zhou, D., Bousquet, O., Lal, T. N., Weston, J., & Schölkopf, B. (2004). @@ -68,75 +62,28 @@ def local_and_global_consistency(G, alpha=0.99, max_iter=30, label_name="label") import scipy as sp import scipy.sparse # call as sp.sparse - def _build_propagation_matrix(X, labels, alpha): - """Build propagation matrix of Local and global consistency - - Parameters - ---------- - X : scipy sparse matrix, shape = [n_samples, n_samples] - Adjacency matrix - labels : array, shape = [n_samples, 2] - Array of pairs of node id and label id - alpha : float - Clamping factor - - Returns - ------- - S : scipy sparse matrix, shape = [n_samples, n_samples] - Propagation matrix - - """ - degrees = X.sum(axis=0).A[0] - degrees[degrees == 0] = 1 # Avoid division by 0 - D2 = np.sqrt(sp.sparse.diags((1.0 / degrees), offsets=0)) - S = alpha * ((D2 @ X) @ D2) - return S - - def _build_base_matrix(X, labels, alpha, n_classes): - """Build base matrix of Local and global consistency - - Parameters - ---------- - X : scipy sparse matrix, shape = [n_samples, n_samples] - Adjacency matrix - labels : array, shape = 
[n_samples, 2] - Array of pairs of node id and label id - alpha : float - Clamping factor - n_classes : integer - The number of classes (distinct labels) on the input graph - - Returns - ------- - B : array, shape = [n_samples, n_classes] - Base matrix - """ - - n_samples = X.shape[0] - B = np.zeros((n_samples, n_classes)) - B[labels[:, 0], labels[:, 1]] = 1 - alpha - return B - X = nx.to_scipy_sparse_matrix(G) # adjacency matrix labels, label_dict = _get_label_info(G, label_name) if labels.shape[0] == 0: raise nx.NetworkXError( - "No node on the input graph is labeled by '" + label_name + "'." + f"No node on the input graph is labeled by '{label_name}'." ) n_samples = X.shape[0] n_classes = label_dict.shape[0] - F = _init_label_matrix(n_samples, n_classes) - - P = _build_propagation_matrix(X, labels, alpha) - B = _build_base_matrix(X, labels, alpha, n_classes) + F = np.zeros((n_samples, n_classes)) - remaining_iter = max_iter - while remaining_iter > 0: - F = _propagate(P, F, B) - remaining_iter -= 1 + # Build propagation matrix + degrees = X.sum(axis=0).A[0] + degrees[degrees == 0] = 1 # Avoid division by 0 + D2 = np.sqrt(sp.sparse.diags((1.0 / degrees), offsets=0)) + P = alpha * ((D2 @ X) @ D2) + # Build base matrix + B = np.zeros((n_samples, n_classes)) + B[labels[:, 0], labels[:, 1]] = 1 - alpha - predicted = _predict(F, label_dict) + for _ in range(max_iter): + F = (P @ F) + B - return predicted + return label_dict[np.argmax(F, axis=1)].tolist() diff --git a/networkx/algorithms/node_classification/utils.py b/networkx/algorithms/node_classification/utils.py index 43f4fd0a..f7d7ac21 100644 --- a/networkx/algorithms/node_classification/utils.py +++ b/networkx/algorithms/node_classification/utils.py @@ -1,24 +1,3 @@ -def _propagate(P, F, B): - """Propagate labels by one step - - Parameters - ---------- - P : scipy sparse matrix, shape = [n_samples, n_samples] - Propagation matrix - F : numpy array, shape = [n_samples, n_classes] - Label matrix - B : numpy array, 
shape = [n_samples, n_classes] - Base matrix - - Returns - ---------- - F_new : array, shape = [n_samples, n_classes] - Label matrix - """ - F_new = (P @ F) + B - return F_new - - def _get_label_info(G, label_name): """Get and return information of labels from the input graph @@ -53,47 +32,3 @@ def _get_label_info(G, label_name): [label for label, _ in sorted(label_to_id.items(), key=lambda x: x[1])] ) return (labels, label_dict) - - -def _init_label_matrix(n_samples, n_classes): - """Create and return zero matrix - - Parameters - ---------- - n_samples : integer - The number of nodes (samples) on the input graph - n_classes : integer - The number of classes (distinct labels) on the input graph - - Returns - ---------- - F : numpy array, shape = [n_samples, n_classes] - Label matrix - """ - import numpy as np - - F = np.zeros((n_samples, n_classes)) - return F - - -def _predict(F, label_dict): - """Predict labels by learnt label matrix - - Parameters - ---------- - F : numpy array, shape = [n_samples, n_classes] - Learnt (resulting) label matrix - label_dict : numpy array, shape = [n_classes] - Array of labels - i-th element contains the label corresponding label ID `i` - - Returns - ---------- - predicted : numpy array, shape = [n_samples] - Array of predicted labels - """ - import numpy as np - - predicted_label_ids = np.argmax(F, axis=1) - predicted = label_dict[predicted_label_ids].tolist() - return predicted |