Source code for pyagc.models.ns4gc

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset
from torch_geometric.transforms import BaseTransform

from pyagc.models.base import TrainableModel, LossOutput
from pyagc.utils import filter_kwargs


[docs]class NS4GC(TrainableModel): r""" The NS4GC (Node Similarity-guided Contrastive Graph Clustering) model is proposed in the `"Reliable Node Similarity Matrix Guided Contrastive Graph Clustering" <https://arxiv.org/abs/2408.03765>`_ paper (Liu et al., TKDE 2024). NS4GC aims to learn node representations suitable for clustering via contrastive learning guided by a node similarity matrix computed from two augmented views of the graph. The model uses the following contrastive loss function: .. math:: \mathcal{L} = \mathcal{L}_{\text{ali}} + \lambda \mathcal{L}_{\text{nei}} + \gamma \mathcal{L}_{\text{spa}} where: - :math:`\boldsymbol{S} = \boldsymbol{Z}^{(1)} (\boldsymbol{Z}^{(2)})^{\top}` is the similarity matrix computed by inner product between two views; - :math:`\mathcal{L}_{\text{ali}} = -\frac{1}{N} \sum_i \boldsymbol{S}_{ii}` encourages aligned embeddings from two views; - :math:`\mathcal{L}_{\text{nei}} = -\frac{1}{|\mathcal{E}|} \sum_{(i,j) \in \mathcal{E}} \boldsymbol{S}_{ij}` enforces consistency for neighbors; - :math:`\mathcal{L}_{\text{spa}} = \mathbb{E}_{(i,j) \notin \mathcal{E}} \sigma((\boldsymbol{S}_{ij} - s) / \tau)` encourages sparsity of similarity values between non-neighbor pairs. :math:`\sigma(x) = 1 / (1 + \exp(-x))` is the sigmoid function. Args: encoder (torch.nn.Module): The encoder shared across both views. transform1 (torch_geometric.transforms.BaseTransform): The 1-st graph view transformation. transform2 (torch_geometric.transforms.BaseTransform): The 2-nd graph view transformation. s (float): Threshold for the sparsity loss term (default: :obj:`0.6`). tau (float): Temperature for the sigmoid function in sparsity loss (default: :obj:`0.1`). lam (float): Weight for the neighborhood consistency term (default: :obj:`1.0`). gam (float): Weight for the sparsity loss term (default: :obj:`1.0`). Example: >>> from pyagc.models import NS4GC >>> from pyagc.transforms import GSSLTransform >>> from pyagc.encoders import create_tuned_gnn >>> >>> # Create encoder >>> encoder = create_tuned_gnn('gcn', in_channels=128, hidden_channels=64, num_layers=2) >>> >>> # Create augmentation transforms >>> transform1 = GSSLTransform(p_feat_mask=0.2, p_edge_drop=0.3) >>> transform2 = GSSLTransform(p_feat_mask=0.2, p_edge_drop=0.3) >>> >>> # Create model >>> model = NS4GC( ... encoder=encoder, ... transform1=transform1, ... transform2=transform2, ... s=0.6, ... tau=0.1, ... lam=1.0, ... gam=1.0 ... ) """ def __init__(self, encoder: Module, transform1: BaseTransform, transform2: BaseTransform, s: float = 0.6, tau: float = 0.1, lam: float = 1.0, gam: float = 1.0): super(NS4GC, self).__init__() self.encoder = encoder self.transform1 = transform1 self.transform2 = transform2 self.s = s self.tau = tau self.lam = lam self.gam = gam
[docs] def reset_parameters(self): r"""Resets all learnable parameters of the module.""" reset(self.encoder)
[docs] def embed(self, *args, **kwargs) -> Tensor: r"""Computes node embeddings.""" return self.encoder(*args, **filter_kwargs(self.encoder.forward, kwargs))
[docs] def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: r"""Generates embeddings from two graph augmentations.""" data1 = self.transform1(*args, **kwargs) data2 = self.transform2(*args, **kwargs) z1 = self.encoder(**data1) z2 = self.encoder(**data2) return z1, z2
def _compute_loss(self, z1: Tensor, z2: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: r""" Computes the total loss and its individual components based on two-view embeddings. Args: z1 (torch.Tensor): The first node embeddings. z2 (torch.Tensor): The second node embeddings. edge_index (torch.Tensor): Edge indices. Returns: Tuple of (total_loss, alignment_loss, neighbor_loss, sparsity_loss). """ device = z1.device z1 = F.normalize(z1, p=2, dim=-1) z2 = F.normalize(z2, p=2, dim=-1) S = z1 @ z2.T num_nodes = z1.size(0) loss_ali = -torch.diag(S).mean() mask = torch.ones((num_nodes, num_nodes), device=device, dtype=torch.bool) mask.fill_diagonal_(False) src, dst = edge_index mask[src, dst] = False loss_nei = -S[src, dst].mean() S_spa = torch.masked_select(S, mask) S_spa = torch.sigmoid((S_spa - self.s) / self.tau) loss_spa = S_spa.mean() loss = loss_ali + self.lam * loss_nei + self.gam * loss_spa return loss, loss_ali, loss_nei, loss_spa
[docs] def loss(self, x: Tensor, edge_index: Tensor, **kwargs) -> LossOutput: r""" Computes the NS4GC loss with multiple components. Args: x (torch.Tensor): Node features. edge_index (torch.Tensor): Edge indices. Returns: LossOutput containing total loss and individual components. """ z1, z2 = self(x, edge_index, **kwargs) loss, loss_ali, loss_nei, loss_spa = self._compute_loss(z1, z2, edge_index) return LossOutput( total=loss, components={ 'ali': loss_ali.item(), 'nei': loss_nei.item(), 'spa': loss_spa.item() } )
[docs] def loss_batch(self, batch: Data) -> LossOutput: r""" Computes loss for a mini-batch with seed node slicing. Args: batch (Data): A mini-batch from the loader. Returns: LossOutput containing total loss and individual components. """ z1, z2 = self(batch.x, batch.edge_index) z1 = z1[:batch.batch_size] z2 = z2[:batch.batch_size] # Extract edges within the batch # input_nodes = torch.arange(batch.batch_size, device=batch.x.device) # batch_edge_index, _ = subgraph(input_nodes, batch.edge_index, edge_attr=None, relabel_nodes=False) batch_mask = (batch.edge_index[0] < batch.batch_size) & (batch.edge_index[1] < batch.batch_size) batch_edge_index = batch.edge_index[:, batch_mask] loss, loss_ali, loss_nei, loss_spa = self._compute_loss(z1, z2, batch_edge_index) return LossOutput( total=loss, components={ 'ali': loss_ali.item(), 'nei': loss_nei.item(), 'spa': loss_spa.item() } )
def __repr__(self) -> str: return f"{self.__class__.__name__}(encoder={self.encoder})"