from typing import Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset
from pyagc.clusters import DinkClusterHead
from pyagc.models.base import ClusteringModel, LossOutput
from pyagc.utils import filter_kwargs
EPS = 1e-15
class DinkNet(ClusteringModel):
    r"""The Dink-Net model from the
    `"Dink-Net: Neural Clustering on Large Graphs"
    <https://arxiv.org/abs/2305.18405>`_ paper (Liu et al., ICML 2023).

    Dink-Net unifies representation learning and clustering optimization via:

    1. **Node Discriminate Module**: Learns discriminative features by
       distinguishing original vs. augmented nodes.
    2. **Neural Clustering Module**: Optimizes clustering via dilation
       (push centers apart) and shrink (pull nodes to centers) losses.

    The total loss is:

    .. math::
        \mathcal{L} = \mathcal{L}_\text{dilation} + \mathcal{L}_\text{shrink}
        + \alpha \mathcal{L}_\text{discri}

    Args:
        encoder (torch.nn.Module): The encoder module (e.g., GCN, GraphSAGE).
        projector (torch.nn.Module): The projection head for discriminative
            learning.
        n_clusters (int): Number of clusters.
        hidden_channels (int): Hidden dimension of node embeddings.
        corruption (Callable, optional): Corruption function for data
            augmentation. It is called as ``corruption(x, edge_index)`` and
            must return a ``(x, edge_index)`` pair. If None, uses random
            feature shuffling. (default: :obj:`None`)
        alpha (float, optional): Trade-off weight for discriminative loss.
            (default: :obj:`1e-10`)
    """

    def __init__(
        self,
        encoder: nn.Module,
        projector: nn.Module,
        n_clusters: int,
        hidden_channels: int,
        corruption: Optional[Callable] = None,
        alpha: float = 1e-10,
    ):
        super().__init__()
        self.encoder = encoder
        self.projector = projector
        self.n_clusters = n_clusters
        self.hidden_channels = hidden_channels
        self.alpha = alpha

        # Default corruption: feature shuffle (a named function rather than
        # an anonymous lambda).
        self.corruption = (
            self._shuffle_features if corruption is None else corruption
        )

        # Initialize cluster head (will be set after pretraining)
        self.cluster_head = DinkClusterHead(n_clusters, self.hidden_channels)

    @staticmethod
    def _shuffle_features(x: Tensor, edge_index: Tensor):
        r"""Default corruption: permutes the rows of :obj:`x` and returns
        :obj:`edge_index` unchanged."""
        # Build the permutation on the same device as `x` so that indexing
        # does not need a CPU index tensor.
        perm = torch.randperm(x.size(0), device=x.device)
        return x[perm], edge_index

    @staticmethod
    def _take_seed_nodes(z: Tensor, batch_size: Optional[int]) -> Tensor:
        r"""Returns the first :obj:`batch_size` rows of :obj:`z`.

        A missing (:obj:`None`) or non-positive :obj:`batch_size` means
        "no slicing" and returns :obj:`z` unchanged.
        """
        if batch_size is not None and batch_size > 0:
            return z[:batch_size]
        return z

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.encoder)
        reset(self.projector)
        self.cluster_head.reset_cluster_centers()

    def embed(self, x: Tensor, edge_index: Tensor, **kwargs) -> Tensor:
        r"""Computes node embeddings.

        Only the keyword arguments selected by :func:`filter_kwargs` for
        ``self.encoder.forward`` are forwarded to the encoder.
        """
        encoder_kwargs = filter_kwargs(self.encoder.forward, kwargs)
        return self.encoder(x, edge_index, **encoder_kwargs)

    def forward(self, x: Tensor, edge_index: Tensor, **kwargs) -> Tensor:
        r"""Returns cluster assignments.

        Embeddings are L2-normalized along the last dimension before being
        passed to the cluster head with ``soft=False``.
        """
        z = self.embed(x, edge_index, **kwargs)
        z = F.normalize(z, p=2, dim=-1)
        return self.cluster_head.cluster(z, soft=False)

    def _discriminate_loss(
        self, x: Tensor, edge_index: Tensor, **kwargs
    ) -> Tensor:
        r"""Computes discriminative loss for node discrimination.

        The original and the corrupted inputs are both encoded and
        projected; the projection is summed over its last dimension to get
        one logit per node. Original nodes are treated as positives and
        corrupted nodes as negatives.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            **kwargs: Forwarded to :meth:`embed`. An optional
                :obj:`batch_size` entry restricts the loss to the first
                :obj:`batch_size` nodes; :obj:`None` or a non-positive
                value uses all nodes.
        """
        # `None` is accepted here; the original `-1` sentinel is still
        # honoured by `_take_seed_nodes`.
        batch_size = kwargs.get('batch_size')

        # Original embeddings
        z1 = self._take_seed_nodes(
            self.embed(x, edge_index, **kwargs), batch_size
        )
        g1 = self.projector(z1).sum(dim=-1)

        # Corrupted embeddings
        x_corrupt, edge_index_corrupt = self.corruption(x, edge_index)
        z2 = self._take_seed_nodes(
            self.embed(x_corrupt, edge_index_corrupt, **kwargs), batch_size
        )
        g2 = self.projector(z2).sum(dim=-1)

        # Binary cross-entropy on logits: softplus(-g) = -log(sigmoid(g))
        # for positives and softplus(g) = -log(1 - sigmoid(g)) for negatives.
        return (F.softplus(-g1) + F.softplus(g2)).mean() / 2

    def pretrain_loss(self, x: Tensor, edge_index: Tensor, **kwargs) -> Tensor:
        r"""Computes pretraining loss (discriminative only)."""
        return self._discriminate_loss(x, edge_index, **kwargs)

    def loss(
        self,
        x: Tensor,
        edge_index: Tensor,
        pretrain: bool = False,
        **kwargs,
    ) -> LossOutput:
        r"""Computes the total loss.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            pretrain (bool, optional): If True, only compute discriminative
                loss. (default: :obj:`False`)
            **kwargs: Forwarded to :meth:`embed`. An optional
                :obj:`batch_size` entry restricts both the discriminative
                and the clustering terms to the first :obj:`batch_size`
                nodes.

        Returns:
            LossOutput containing total loss and individual components.
        """
        loss_discri = self._discriminate_loss(x, edge_index, **kwargs)
        if pretrain:
            return LossOutput(
                total=loss_discri,
                components={'discri': loss_discri.item()},
            )

        # Fine-tuning: discriminative + clustering losses
        z = self.embed(x, edge_index, **kwargs)
        # Slice exactly like the discriminative term so that both terms are
        # computed over the same set of nodes.
        z = self._take_seed_nodes(z, kwargs.get('batch_size'))
        z = F.normalize(z, p=2, dim=-1)
        loss_dilation, loss_shrink = self.cluster_head(z)

        total_loss = loss_dilation + loss_shrink + self.alpha * loss_discri
        return LossOutput(
            total=total_loss,
            components={
                'dilation': loss_dilation.item(),
                'shrink': loss_shrink.item(),
                'discri': loss_discri.item(),
            },
        )

    def loss_batch(self, batch: Data, pretrain: bool = False) -> LossOutput:
        r"""Computes loss for a mini-batch with seed node slicing.

        Args:
            batch (Data): A mini-batch from the loader. Its
                :obj:`batch_size` attribute gives the number of leading
                (seed) nodes the loss is computed on.
            pretrain (bool, optional): If True, only compute discriminative
                loss. (default: :obj:`False`)

        Returns:
            LossOutput containing total loss and individual components.
        """
        # Same computation as `loss`, restricted to the seed nodes.
        return self.loss(
            batch.x,
            batch.edge_index,
            pretrain=pretrain,
            batch_size=batch.batch_size,
        )

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'n_clusters={self.n_clusters}, '
                f'alpha={self.alpha})')