# Source code for pyagc.models.dmon

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset

from pyagc.clusters import DMoNClusterHead
from pyagc.models.base import ClusteringModel, LossOutput
from pyagc.utils import filter_kwargs


class DMoN(ClusteringModel):
    r"""Deep Modularity Network (DMoN) for unsupervised graph clustering.

    DMoN was introduced in the `"Graph Clustering with Graph Neural Networks"
    <https://arxiv.org/abs/2006.16904>`_ paper (Tsitsulin et al., JMLR 2023).

    A graph encoder (e.g., GCN, GraphSAGE) is paired with a
    :class:`~pyagc.clusters.DMoNClusterHead`. The encoder maps nodes to
    embeddings :math:`\mathbf{Z}`, and the head projects those embeddings
    into soft cluster assignments :math:`\mathbf{S}`. Training jointly
    optimizes a modularity objective and a collapse regularizer.

    **(1) Spectral modularity loss:**

    .. math::
        \mathcal{L}_m = - \frac{1}{2m} \mathrm{Tr}(\mathbf{S}^\top \mathbf{B} \mathbf{S})

    where :math:`\mathbf{B} = \mathbf{A} - \frac{\mathbf{d}\mathbf{d}^\top}{2m}`
    is the modularity matrix and :math:`m` is the total number of edges.
    It rewards placing nodes that are connected more densely than expected
    at random into the same cluster.

    **(2) Collapse regularization loss:**

    .. math::
        \mathcal{L}_c = \frac{\sqrt{K}}{N} \left\| \sum_i \mathbf{S}_i^\top \right\|_F - 1

    which penalizes unbalanced or collapsed assignments and thereby rules
    out degenerate solutions.

    The two terms are combined as

    .. math::
        \mathcal{L} = \mathcal{L}_m + \lambda \mathcal{L}_c

    with :math:`\lambda` setting the strength of the regularization.

    Args:
        encoder (torch.nn.Module): Node encoder that outputs node embeddings.
        n_features (int): Feature dimension of the encoder outputs.
        n_clusters (int): Number of clusters.
        lam (float, optional): Regularization coefficient for the collapse
            loss :math:`\mathcal{L}_c`. (default: :obj:`1.0`)
    """

    def __init__(self, encoder: nn.Module, n_features: int, n_clusters: int,
                 lam: float = 1.0):
        super().__init__()

        # Plain hyper-parameters.
        self.lam = lam
        self.n_clusters = n_clusters
        self.n_features = n_features

        # Submodules: the encoder is registered before the head so the
        # module/parameter ordering stays encoder -> head.
        self.encoder = encoder
        self.head = DMoNClusterHead(n_clusters, n_features)

    def reset_parameters(self):
        r"""Re-initializes the encoder and the clustering head."""
        reset(self.encoder)
        self.head.reset_cluster_centers()

    def embed(self, *args, **kwargs) -> Tensor:
        r"""Runs the encoder to obtain node embeddings.

        Positional arguments are forwarded untouched; keyword arguments are
        first passed through :func:`~pyagc.utils.filter_kwargs` together with
        the encoder's ``forward`` method.
        """
        encoder_kwargs = filter_kwargs(self.encoder.forward, kwargs)
        return self.encoder(*args, **encoder_kwargs)

    def forward(self, *args, **kwargs) -> Tensor:
        r"""Predicts hard cluster assignments with the current parameters.

        The inputs are embedded with :meth:`embed` and the embeddings are
        handed to the head's ``cluster`` method.
        """
        return self.head.cluster(self.embed(*args, **kwargs))

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs) -> LossOutput:
        r"""Computes the DMoN objective and its individual terms.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            **kwargs: Extra keyword arguments forwarded to :meth:`embed`.

        Returns:
            LossOutput: The total loss
            :math:`\mathcal{L}_m + \lambda \mathcal{L}_c` together with the
            ``'modularity'`` and ``'collapse'`` components.
        """
        embeddings = self.embed(x, edge_index, **kwargs)
        modularity_loss, collapse_loss = self.head(z=embeddings,
                                                   edge_index=edge_index)

        # Components are converted with ``.item()``; the total is passed on
        # unchanged.
        components = {
            'modularity': modularity_loss.item(),
            'collapse': collapse_loss.item(),
        }
        total = modularity_loss + self.lam * collapse_loss
        return LossOutput(total=total, components=components)

    def loss_batch(self, batch: Data):
        r"""Mini-batch training is currently not supported by DMoN.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support batch training.")

    def __repr__(self):
        model_name = self.__class__.__name__
        encoder_name = self.encoder.__class__.__name__
        return f"{model_name}(lam={self.lam}, encoder={encoder_name})"