# Source code for pyagc.models.gcsbm

from typing import Any

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset

from pyagc.clusters import SBMClusterHead, SBMMatchClusterHead
from pyagc.models.base import ClusteringModel, LossOutput
from pyagc.utils import filter_kwargs


class GCSBM(ClusteringModel):
    r"""Stochastic Block Model based clustering model (GCSBM) from the paper
    `"Differentiable Community Detection with Graph Neural Networks and
    Stochastic Block Models" <https://openreview.net/forum?id=T1vdfm1THf>`_
    (Arliss & Mueller, LoG 2025).

    This model learns node embeddings via a GNN encoder and performs
    clustering by maximizing the likelihood of a GCSBM-based generative
    model. It supports multiple GCSBM variants including Bernoulli, Poisson,
    and their degree-corrected versions.

    The model optimizes:

    .. math::
        \mathcal{L} = m^{-1} \mathcal{L}_{GCSBM} + \alpha
        \|\mathbf{1}_k - \mathrm{diag}(\hat{\Theta})\|_F

    where :math:`\mathcal{L}_{GCSBM}` is the negative log-likelihood,
    :math:`m` is the number of edges, and :math:`\alpha` controls the
    regularization strength.

    **Supported Variants:**

    - :obj:`'bernoulli'`: Standard Bernoulli GCSBM
    - :obj:`'poisson'`: Partial Poisson GCSBM (suitable for simple graphs)
    - :obj:`'bernoulli-dc'`: Degree-corrected Bernoulli GCSBM
    - :obj:`'poisson-dc'`: Degree-corrected Poisson GCSBM
    - :obj:`'match'`: Graph Matching variant (fastest)

    Args:
        encoder (torch.nn.Module): Node encoder that outputs node embeddings.
        n_features (int): Feature dimension of the encoder outputs.
        n_clusters (int): Number of clusters.
        variant (str, optional): GCSBM variant to use. Options:
            :obj:`'bernoulli'`, :obj:`'poisson'`, :obj:`'bernoulli-dc'`,
            :obj:`'poisson-dc'`, :obj:`'match'`.
            (default: :obj:`'bernoulli'`)
        eta (float, optional): Negative sampling ratio. Not passed to the
            :obj:`'match'` head. (default: :obj:`3.0`)
        alpha (float, optional): Regularization strength.
            (default: :obj:`1.0`)
    """

    def __init__(
        self,
        encoder: nn.Module,
        n_features: int,
        n_clusters: int,
        variant: str = 'bernoulli',
        eta: float = 3.0,
        alpha: float = 1.0,
    ):
        super().__init__()
        self.encoder = encoder
        self.n_features = n_features
        self.n_clusters = n_clusters
        self.variant = variant
        self.eta = eta
        self.alpha = alpha

        # Create the appropriate clustering head. The 'match' head takes no
        # variant/eta arguments; every other variant string is handed to
        # SBMClusterHead as-is.
        if variant == 'match':
            self.head = SBMMatchClusterHead(n_clusters, n_features)
        else:
            self.head = SBMClusterHead(n_clusters, n_features,
                                       variant=variant, eta=eta)

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.encoder)
        self.head.reset_cluster_centers()

    def embed(self, *args, **kwargs) -> Tensor:
        r"""Compute node embeddings via the encoder.

        Keyword arguments are passed through ``filter_kwargs`` against the
        encoder's ``forward`` before the call.
        """
        return self.encoder(*args, **filter_kwargs(self.encoder.forward, kwargs))

    def forward(self, *args, **kwargs) -> Tensor:
        r"""Predict hard cluster assignments from current parameters."""
        z = self.embed(*args, **kwargs)
        return self.head.cluster(z)

    def _build_loss_output(self, likelihood_loss: Tensor,
                           reg_loss: Tensor) -> LossOutput:
        r"""Combine the likelihood and (alpha-weighted) regularization terms
        into a :class:`LossOutput`; components are reported as Python floats.
        """
        weighted_reg = self.alpha * reg_loss
        total_loss = likelihood_loss + weighted_reg
        return LossOutput(
            total=total_loss,
            components={
                'likelihood': likelihood_loss.item(),
                'reg': weighted_reg.item(),
            },
        )

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs: Any) -> LossOutput:
        r"""Computes the GCSBM loss with multiple components.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            **kwargs: Extra keyword arguments forwarded to :meth:`embed`.

        Returns:
            LossOutput containing total loss and individual components.
        """
        z = self.embed(x, edge_index, **kwargs)
        likelihood_loss, reg_loss = self.head(z=z, edge_index=edge_index)
        return self._build_loss_output(likelihood_loss, reg_loss)

    def loss_batch(self, batch: Data, **kwargs: Any) -> LossOutput:
        r"""Computes loss for a mini-batch with seed node slicing.

        Embeddings are computed on the whole sampled subgraph. Only the first
        ``batch.batch_size`` nodes and the edges whose endpoints are both
        among them enter the loss.

        Args:
            batch (Data): A mini-batch from the loader.
            **kwargs: Extra keyword arguments forwarded to :meth:`embed`
                (consistent with :meth:`loss`).

        Returns:
            LossOutput containing total loss and individual components.
        """
        z = self.embed(batch.x, batch.edge_index, **kwargs)

        n_seeds = batch.batch_size
        z = z[:n_seeds]

        # Extract edges within the batch (both endpoints are seed nodes).
        src, dst = batch.edge_index[0], batch.edge_index[1]
        seed_mask = (src < n_seeds) & (dst < n_seeds)
        seed_edge_index = batch.edge_index[:, seed_mask]

        if seed_edge_index.size(1) == 0:
            # No edges in the batch: return a zero loss. It must still
            # support ``backward()`` (a plain ``torch.tensor(0.0)`` does not
            # require grad and would make ``total.backward()`` raise), so make
            # it require grad whenever autograd is enabled.
            zero = torch.zeros(
                (), device=z.device, dtype=z.dtype,
                requires_grad=torch.is_grad_enabled(),
            )
            return LossOutput(
                total=zero,
                components={'likelihood': 0.0, 'reg': 0.0},
            )

        likelihood_loss, reg_loss = self.head(z=z, edge_index=seed_edge_index)
        return self._build_loss_output(likelihood_loss, reg_loss)

    def __repr__(self):
        return (f"{self.__class__.__name__}(variant={self.variant}, "
                f"n_clusters={self.n_clusters}, alpha={self.alpha}, "
                f"encoder={self.encoder.__class__.__name__})")