from typing import Any
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset
from pyagc.clusters import MinCutClusterHead
from pyagc.models.base import ClusteringModel, LossOutput
from pyagc.utils import filter_kwargs
class MinCut(ClusteringModel):
    r"""
    The MinCut model is based on
    `"Spectral Clustering in Graph Neural Networks for Graph Pooling"
    <https://arxiv.org/abs/1907.00481>`_ (Bianchi et al., 2019).

    It performs **unsupervised graph clustering** by coupling a graph encoder
    (e.g., GCN, GraphSAGE) with the :class:`~pyagc.clusters.MinCutClusterHead`.
    The encoder produces node embeddings :math:`\mathbf{Z}`, which are projected
    into soft cluster assignments :math:`\mathbf{S}`. The model jointly
    optimizes the MinCut objective and the orthogonality regularizer to yield
    compact, well-separated clusters.

    The optimization objective consists of two losses:

    **(1) MinCut loss:**

    .. math::
        \mathcal{L}_{\text{mincut}} = -
        \frac{\mathrm{Tr}(\mathbf{S}^\top \mathbf{A} \mathbf{S})}
        {\mathrm{Tr}(\mathbf{S}^\top \mathbf{D} \mathbf{S})}

    encouraging large within-cluster connectivity relative to cluster volume.

    **(2) Orthogonality regularization:**

    .. math::
        \mathcal{L}_{\text{ortho}} =
        \left\| \frac{\mathbf{S}^\top \mathbf{S}}{\|\mathbf{S}^\top \mathbf{S}\|_F}
        - \frac{\mathbf{I}_K}{\sqrt{K}} \right\|_F

    encouraging near-orthogonal cluster assignment columns to avoid collapse.

    The final training objective is a weighted combination of the two terms:

    .. math::
        \mathcal{L} = \mathcal{L}_{\text{mincut}} + \lambda \mathcal{L}_{\text{ortho}}

    where :math:`\lambda` controls the strength of the orthogonality regularization.

    Args:
        encoder (torch.nn.Module): Node encoder that outputs node embeddings.
        n_features (int): Feature dimension of the encoder outputs.
        n_clusters (int): Number of clusters.
        lam (float, optional): Regularization coefficient for the
            orthogonality loss :math:`\mathcal{L}_{\text{ortho}}`. (default: :obj:`1.0`)
        temperature (float, optional): Softmax temperature used in the MinCut head.
            (default: :obj:`1.0`)
    """

    def __init__(
        self,
        encoder: nn.Module,
        n_features: int,
        n_clusters: int,
        lam: float = 1.0,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.encoder = encoder
        self.n_features = n_features
        self.n_clusters = n_clusters
        self.lam = lam

        # MinCut clustering head
        # NOTE: ``temperature`` is only forwarded to the head; it is not kept
        # as an attribute on this model.
        self.head = MinCutClusterHead(
            n_clusters=n_clusters,
            n_features=n_features,
            temperature=temperature,
        )

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module.

        The encoder is re-initialized through
        :func:`torch_geometric.nn.inits.reset`, and the head through its own
        ``reset_cluster_centers`` method.
        """
        reset(self.encoder)
        self.head.reset_cluster_centers()

    def embed(self, *args, **kwargs) -> Tensor:
        r"""Compute node embeddings via the encoder.

        Positional arguments are forwarded to the encoder unchanged. Keyword
        arguments are first passed through :func:`~pyagc.utils.filter_kwargs`
        against ``self.encoder.forward`` (presumably keeping only the keywords
        that the encoder's ``forward`` accepts -- confirm against the helper).

        Returns:
            torch.Tensor: The encoder output :math:`\mathbf{Z}`.
        """
        return self.encoder(*args, **filter_kwargs(self.encoder.forward, kwargs))

    def forward(self, *args, **kwargs) -> Tensor:
        r"""Predict hard cluster assignments from current parameters.

        All arguments are forwarded to :meth:`embed`; the resulting embeddings
        are handed to the head's ``cluster`` method.
        """
        z = self.embed(*args, **kwargs)
        return self.head.cluster(z)

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs: Any) -> LossOutput:
        r"""
        Computes the MinCut loss with multiple components.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            **kwargs: Extra keyword arguments forwarded to :meth:`embed`.

        Returns:
            LossOutput containing total loss and individual components.
            ``total`` is ``mincut + lam * ortho``; ``components`` holds the
            ``'mincut'`` and ``'ortho'`` terms converted with ``.item()``.
        """
        z = self.embed(x, edge_index, **kwargs)
        loss_mincut, loss_ortho = self.head(z=z, edge_index=edge_index)

        # Weighted sum of the two objectives; ``lam`` scales the regularizer.
        loss = loss_mincut + self.lam * loss_ortho

        return LossOutput(
            total=loss,
            components={
                'mincut': loss_mincut.item(),
                'ortho': loss_ortho.item(),
            },
        )

    def loss_batch(self, batch: Data, **kwargs: Any):
        r"""MinCut currently does not support mini-batch training.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support batch training.")

    def __repr__(self):
        return (f"{self.__class__.__name__}(lam={self.lam}, "
                f"encoder={self.encoder.__class__.__name__})")