Source code for pyagc.models.magi

from typing import Union, List, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.inits import reset
from torch_geometric.sampler import NodeSamplerInput
# from torch_sparse import SparseTensor
from torch_geometric.typing import SparseTensor

from pyagc.models.base import TrainableModel, LossOutput
from pyagc.utils import filter_kwargs


[docs]class MAGI(TrainableModel): r"""The MAGI (Modularity-Aware Graph clustering via contrastive learnIng) model from the `"Revisiting Modularity Maximization for Graph Clustering: A Contrastive Learning Perspective" <https://arxiv.org/abs/2406.14288>`_ paper (Liu et al., KDD 2024). MAGI establishes the connection between modularity maximization and graph contrastive learning, where positive and negative samples are naturally guided by the modularity matrix. The model uses a community-aware pretext task based on two-stage random walks to capture high-order proximity within communities. The loss function follows InfoNCE-style contrastive learning: .. math:: \mathcal{L} = -\sum_{v \in \mathcal{B}} \sum_{u \in \mathcal{M}^+_v} \log \frac{\exp(\mathbf{z}_v^\top \mathbf{z}_u / \tau)} {\sum_{u \in \mathcal{M}^+_v} \exp(\mathbf{z}_v^\top \mathbf{z}_u / \tau) + \sum_{u' \in \mathcal{M}^-_v} \exp(\mathbf{z}_v^\top \mathbf{z}_{u'} / \tau)} where: - :math:`\mathcal{M}^+_v = \{u \mid B_{vu} > 0\}` are positive samples (same community). - :math:`\mathcal{M}^-_v = \{u \mid B_{vu} \leq 0\}` are negative samples (different communities). - :math:`B_{vu}` is the modularity coefficient computed via two-stage random walks. - :math:`\tau` is the temperature parameter. Args: encoder (torch.nn.Module): The GNN encoder module. tau (float, optional): Temperature parameter for contrastive loss. (default: :obj:`0.5`) scale_embeddings (bool, optional): Whether to apply min-max scaling to embeddings before normalization. (default: :obj:`True`) """ def __init__( self, encoder: Module, tau: float = 0.5, scale_embeddings: bool = True, ): super().__init__() self.encoder = encoder self.tau = tau self.scale_embeddings = scale_embeddings
[docs] def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset(self.encoder)
[docs] def embed(self, *args, **kwargs) -> Tensor: r"""Computes node embeddings. Returns: Node embeddings of shape :obj:`(num_nodes, hidden_dim)`. """ z = self.encoder(*args, **filter_kwargs(self.encoder.forward, kwargs)) if self.scale_embeddings: z = self._scale(z) z = F.normalize(z, p=2, dim=-1) return z
def _scale(self, z: Tensor) -> Tensor: r"""Applies min-max scaling to embeddings.""" z_max = z.max(dim=-1, keepdim=True)[0] z_min = z.min(dim=-1, keepdim=True)[0] z_std = (z - z_min) / (z_max - z_min + 1e-20) return z_std def _compute_loss( self, z: Tensor, pos_pairs: Tensor, ) -> Tensor: r"""Computes the InfoNCE contrastive loss based on positive sample pairs. Args: z (torch.Tensor): Normalized embeddings of shape :obj:`(num_nodes, hidden_dim)`. pos_pairs (torch.Tensor): Positive sample pairs of shape :obj:`(2, num_pos_edges)`, where each column :obj:`[i, j]` indicates that nodes :obj:`i` and :obj:`j` are in the same community. Returns: Scalar loss tensor. """ # Compute full similarity matrix: S[i,j] = z_i^T z_j sim = torch.matmul(z, z.T) / self.tau # (num_nodes, num_nodes) # For numerical stability, subtract max along each row sim = sim - sim.max(dim=1, keepdim=True)[0].detach() # Compute exponentials of similarities exp_sim = torch.exp(sim) # (num_nodes, num_nodes) # Compute denominator: sum over all non-self nodes # Use subtraction to exclude diagonal (self-loops) without in-place operations denom = exp_sim.sum(dim=1, keepdim=True) - torch.diag(exp_sim).unsqueeze(1) # (num_nodes, 1) # Extract positive pairs src, dst = pos_pairs # (num_pos_edges,) # Compute log probabilities for positive pairs log_prob = sim[src, dst] - torch.log(denom[src].squeeze(-1) + 1e-15) # Average over all positive pairs loss = -log_prob.mean() return loss
[docs] def loss(self, x: Tensor, edge_index: Tensor, pos_pairs: Tensor, **kwargs) -> LossOutput: r"""Computes the MAGI contrastive loss for full-graph training. Args: x (torch.Tensor): Node feature matrix of shape :obj:`(num_nodes, num_features)`. edge_index (torch.Tensor): Edge indices of shape :obj:`(2, num_edges)`. pos_pairs (torch.Tensor): Precomputed positive sample pairs of shape :obj:`(2, num_pos_edges)` based on modularity matrix. **kwargs: Additional arguments for the encoder. Returns: LossOutput containing the total loss. """ z = self.embed(x, edge_index, **kwargs) loss = self._compute_loss(z, pos_pairs) return LossOutput( total=loss, components={'contrastive': loss.item()} )
[docs] def loss_batch(self, batch: Data) -> LossOutput: r"""Computes loss for a mini-batch with expanded seed nodes. In mini-batch training, the positive pairs are defined only among the expanded batch seed nodes (obtained via two-stage random walks). Args: batch (Data): A mini-batch from the MAGI loader containing: - :obj:`x`: Node features (including neighbors) - :obj:`edge_index`: Sampled edges - :obj:`pos_pairs`: Positive pairs for expanded batch seed nodes - :obj:`expanded_batch_size`: Number of expanded seed nodes Returns: LossOutput containing the total loss. """ z = self.embed(batch.x, batch.edge_index) # Slice embeddings to expanded batch seed nodes z_batch = z[:batch.expanded_batch_size] # Positive pairs are already relative to expanded batch nodes loss = self._compute_loss(z_batch, batch.pos_pairs) return LossOutput( total=loss, components={'contrastive': loss.item()} )
def __repr__(self) -> str: return (f"{self.__class__.__name__}(" f"encoder={self.encoder}, " f"tau={self.tau})")
class MAGIRandomWalkSampler: r"""Two-stage random walk sampler for MAGI model. This sampler implements the two-stage random walk strategy described in the MAGI paper to construct mini-batch modularity matrices and positive sample pairs. Stage 1 (S1): Sample multiple sub-communities by performing random walks from root nodes and filtering nodes visited more than the mean frequency. Stage 2 (S2): Compute similarity matrix via random walks on the expanded batch and derive modularity-based positive/negative pairs. Args: adj (torch_sparse.SparseTensor): Full graph adjacency matrix in SparseTensor format. num_walks (int, optional): Number of random walks per node (wt in paper). (default: :obj:`20`) walk_length (int, optional): Length of each random walk (wl in paper). (default: :obj:`4`) """ def __init__( self, adj: SparseTensor, num_walks: int = 20, walk_length: int = 4, ): self.adj = adj.cpu() self.num_walks = num_walks self.walk_length = walk_length def stage1_expand_batch(self, root_nodes: Tensor) -> Tensor: r"""Stage 1: Expand root nodes to sub-communities via random walks. For each root node, perform random walks and keep nodes that are visited more frequently than the average. Args: root_nodes (torch.Tensor): Root node indices of shape :obj:`(num_roots,)`. Returns: Expanded batch nodes of shape :obj:`(expanded_size,)`. """ num_roots = root_nodes.size(0) # Repeat root nodes for multiple walks start_nodes = root_nodes.repeat_interleave(self.num_walks) # (num_roots * num_walks,) # Perform random walks (excluding starting nodes) walks = self.adj.random_walk(start_nodes, self.walk_length) # Handle both tuple and tensor returns if isinstance(walks, tuple): walks = walks[0] walks = walks[:, 1:] # Exclude start node, shape: (num_roots * num_walks, walk_length) # Reshape to (num_roots, num_walks * walk_length) walks = walks.reshape(num_roots, -1) # Collect expanded batch expanded_batch = [] for i in range(num_roots): rw_nodes, counts = torch.unique(walks[i], return_counts=True) mean_count = counts.float().mean() # Keep nodes visited more than mean mask = counts > mean_count selected = rw_nodes[mask].tolist() expanded_batch.extend(selected) # Add original root nodes expanded_batch.extend(root_nodes.tolist()) # Remove duplicates and convert to tensor expanded_batch = torch.tensor(list(set(expanded_batch)), dtype=torch.long) return expanded_batch def stage2_compute_positive_pairs(self, expanded_batch: Tensor) -> Tensor: r"""Stage 2: Compute positive sample pairs based on modularity matrix. Perform random walks on expanded batch nodes to compute a similarity matrix, then derive modularity coefficients and identify positive pairs. Args: expanded_batch (torch.Tensor): Expanded batch node indices of shape :obj:`(expanded_size,)`. Returns: Positive sample pairs of shape :obj:`(2, num_pos_pairs)`, where each column :obj:`[i, j]` contains local indices (relative to expanded_batch) of nodes in the same community. """ batch_size = expanded_batch.size(0) # Repeat for multiple walks start_nodes = expanded_batch.repeat_interleave(self.num_walks) # Perform random walks walks = self.adj.random_walk(start_nodes, self.walk_length) if isinstance(walks, tuple): walks = walks[0] walks = walks[:, 1:] # Exclude start node walks = walks.reshape(batch_size, -1) # Build visit count matrix row_indices = [] col_indices = [] values = [] # Create mapping from global to local indices global_to_local = {int(node): i for i, node in enumerate(expanded_batch)} for i in range(batch_size): visited_nodes, counts = torch.unique(walks[i], return_counts=True) for node, count in zip(visited_nodes, counts): node_int = int(node) if node_int in global_to_local: j = global_to_local[node_int] row_indices.append(i) col_indices.append(j) values.append(int(count)) # Build sparse visit matrix if len(values) > 0: row = torch.tensor(row_indices, dtype=torch.long) col = torch.tensor(col_indices, dtype=torch.long) val = torch.tensor(values, dtype=torch.float) # Compute similarity matrix S[i,j] = visit_count(i->j) / total_visits(i) row_sums = torch.zeros(batch_size) row_sums.scatter_add_(0, row, val) # Normalize by row sums sim_values = val / (row_sums[row] + 1e-15) # Compute modularity matrix: B[i,j] = S[i,j] - 1/|batch| mod_values = sim_values - (1.0 / batch_size) # Extract positive pairs (B[i,j] > 0) pos_mask = mod_values > 0 pos_row = row[pos_mask] pos_col = col[pos_mask] pos_pairs = torch.stack([pos_row, pos_col], dim=0) # (2, num_pos_pairs) else: # Fallback: no valid pairs found, return empty tensor pos_pairs = torch.empty((2, 0), dtype=torch.long) return pos_pairs def sample(self, root_nodes: Tensor) -> Tuple[Tensor, Tensor]: r"""Performs two-stage sampling for a batch of root nodes. Args: root_nodes (torch.Tensor): Root node indices of shape :obj:`(num_roots,)`. Returns: Tuple of: - Expanded batch nodes of shape :obj:`(expanded_size,)` - Positive sample pairs of shape :obj:`(2, num_pos_pairs)` """ # Stage 1: Expand batch expanded_batch = self.stage1_expand_batch(root_nodes) # Stage 2: Compute positive pairs pos_pairs = self.stage2_compute_positive_pairs(expanded_batch) return expanded_batch, pos_pairs class MAGINeighborLoader(NeighborLoader): r"""A specialized neighbor loader for MAGI that performs two-stage random walk sampling to construct mini-batches with community-aware structure. This loader extends :class:`~torch_geometric.loader.NeighborLoader` by: 1. Expanding initial seed nodes to sub-communities (Stage 1) 2. Computing positive sample pairs based on modularity (Stage 2) 3. Sampling neighbors for the expanded batch nodes Args: data (Data or HeteroData): The graph data object. num_neighbors (List[int]): Number of neighbors to sample per layer. num_walks (int, optional): Number of random walks per node for MAGI sampling. (default: :obj:`20`) walk_length (int, optional): Length of random walks. (default: :obj:`4`) batch_size (int, optional): Number of seed nodes per batch. (default: :obj:`128`) **kwargs: Additional arguments for :class:`NeighborLoader`. Example: >>> from torch_geometric.datasets import Planetoid >>> from pyagc.models.magi import MAGINeighborLoader >>> data = Planetoid(root='data', name='Cora')[0] >>> loader = MAGINeighborLoader( ... data, ... num_neighbors=[10, 10], ... num_walks=20, ... walk_length=4, ... batch_size=128, ... ) >>> for batch in loader: ... # batch.x: node features ... # batch.edge_index: sampled edges ... # batch.pos_pairs: positive pairs for contrastive loss ... # batch.expanded_batch_size: size of expanded batch ... pass """ def __init__( self, data: Union[Data, HeteroData], num_neighbors: List[int], num_walks: int = 20, walk_length: int = 4, batch_size: int = 128, **kwargs, ): # Initialize parent NeighborLoader # We'll override the sampling behavior in collate_fn super().__init__( data, num_neighbors=num_neighbors, batch_size=batch_size, **kwargs, ) # Build adjacency matrix for random walks edge_index = data.edge_index.cpu() num_nodes = data.num_nodes self.adj = SparseTensor( row=edge_index[0], col=edge_index[1], sparse_sizes=(num_nodes, num_nodes), ) self.adj.fill_value_(1.0) # Initialize MAGI random walk sampler self.magi_sampler = MAGIRandomWalkSampler( adj=self.adj, num_walks=num_walks, walk_length=walk_length, ) def collate_fn(self, index: Union[Tensor, List[int]]) -> Data: r"""Modified collate function that performs MAGI two-stage sampling. Args: index (Tensor or List[int]): Indices of seed nodes in this batch. Returns: Data object containing sampled subgraph with additional attributes: - :obj:`pos_pairs`: Positive sample pairs - :obj:`expanded_batch_size`: Number of expanded seed nodes """ if not isinstance(index, Tensor): index = torch.tensor(index, dtype=torch.long) # Get original seed nodes input_data: NodeSamplerInput = self.input_data[index] original_seed_nodes = input_data.node # Stage 1: Expand batch via random walks expanded_batch, pos_pairs = self.magi_sampler.sample(original_seed_nodes) expanded_batch_size = expanded_batch.size(0) # Create new input data with expanded batch as seed nodes expanded_input_data = NodeSamplerInput( # input_id=torch.arange(expanded_batch_size), input_id=expanded_batch, node=expanded_batch, time=None, input_type=input_data.input_type, ) # Perform neighbor sampling on expanded batch out = self.node_sampler.sample_from_nodes(expanded_input_data) if self.filter_per_worker: out = self.filter_fn(out) # Add MAGI-specific attributes out.pos_pairs = pos_pairs out.expanded_batch_size = expanded_batch_size out.batch_size = index.size(0) return out def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'num_walks={self.magi_sampler.num_walks}, ' f'walk_length={self.magi_sampler.walk_length})') def precompute_full_graph_positive_pairs( data: Data, num_walks: int = 100, walk_length: int = 2, ) -> Tensor: r"""Precomputes positive sample pairs for full-graph training. This function performs the two-stage random walk on all nodes in the graph and returns positive pairs based on the modularity matrix. Args: data (Data): The graph data object. num_walks (int, optional): Number of random walks per node. (default: :obj:`100`) walk_length (int, optional): Length of random walks. (default: :obj:`2`) Returns: Positive sample pairs of shape :obj:`(2, num_pos_pairs)`. Example: >>> from torch_geometric.datasets import Planetoid >>> data = Planetoid(root='data', name='Cora')[0] >>> pos_pairs = precompute_full_graph_positive_pairs(data) >>> data.pos_pairs = pos_pairs # Store in data object """ num_nodes = data.num_nodes edge_index = data.edge_index.cpu() # Build adjacency matrix adj = SparseTensor( row=edge_index[0], col=edge_index[1], sparse_sizes=(num_nodes, num_nodes), ) adj.fill_value_(1.0) # Stage 1: For full graph, all nodes are considered (no expansion needed) all_nodes = torch.arange(num_nodes) # Stage 2: Compute similarity via random walks start_nodes = all_nodes.repeat_interleave(num_walks) walks = adj.random_walk(start_nodes, walk_length) if isinstance(walks, tuple): walks = walks[0] walks = walks[:, 1:] # Exclude start node walks = walks.reshape(num_nodes, -1) # Build visit count matrix row_indices = [] col_indices = [] values = [] for i in range(num_nodes): visited_nodes, counts = torch.unique(walks[i], return_counts=True) for node, count in zip(visited_nodes, counts): j = int(node) if j < num_nodes: # Valid node row_indices.append(i) col_indices.append(j) values.append(int(count)) # Build sparse similarity matrix row = torch.tensor(row_indices, dtype=torch.long) col = torch.tensor(col_indices, dtype=torch.long) val = torch.tensor(values, dtype=torch.float) # Compute row sums for normalization row_sums = torch.zeros(num_nodes) row_sums.scatter_add_(0, row, val) # Similarity matrix: S[i,j] = visit_count(i->j) / total_visits(i) sim_values = val / (row_sums[row] + 1e-15) # Modularity matrix: B[i,j] = S[i,j] - 1/N mod_values = sim_values - (1.0 / num_nodes) # Extract positive pairs (B[i,j] > 0) pos_mask = mod_values > 0 pos_row = row[pos_mask] pos_col = col[pos_mask] pos_pairs = torch.stack([pos_row, pos_col], dim=0) return pos_pairs