Source code for pyagc.clusters.mincut_cluster_head

from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.utils import degree

from pyagc.clusters import BaseClusterHead

# Numerical-stability constant: added to loss denominators so that an
# all-zero degree trace or norm never causes a division by zero.
EPS: float = 1.0e-15


class MinCutClusterHead(BaseClusterHead):
    r"""MinCut clustering head.

    Proposed in `"Spectral Clustering with Graph Neural Networks for Graph
    Pooling" <https://arxiv.org/abs/1907.00481>`_ (Bianchi et al., ICML 2020).

    This layer learns a **soft cluster assignment matrix** :math:`\mathbf{S}`
    by projecting node embeddings :math:`\mathbf{Z}` onto :math:`K` learnable
    cluster centers followed by a temperature-scaled softmax. It jointly
    optimizes two objectives:

    **(1) MinCut loss:**

    .. math::
        \mathcal{L}_{\text{mincut}} =
        - \frac{\mathrm{Tr}(\mathbf{S}^\top \mathbf{A} \mathbf{S})}
               {\mathrm{Tr}(\mathbf{S}^\top \mathbf{D} \mathbf{S})}

    where :math:`\mathbf{A}` is the adjacency matrix built from
    :obj:`edge_index` (duplicate edges are summed) and :math:`\mathbf{D}` is
    the diagonal degree matrix (out-degree of each node in :obj:`edge_index`).

    **(2) Orthogonality loss:**

    .. math::
        \mathcal{L}_{\text{ortho}} =
        \left\| \frac{\mathbf{S}^\top \mathbf{S}}{\|\mathbf{S}^\top \mathbf{S}\|_F}
        - \frac{\mathbf{I}_K}{\sqrt{K}} \right\|_F

    which encourages near-orthogonal cluster assignments.

    Args:
        n_clusters (int): Number of clusters :math:`K`.
        n_features (int): Feature dimension of node embeddings :math:`F`.
        temperature (float, optional): Softmax temperature. (default: 1.0)
    """

    def __init__(self, n_clusters: int, n_features: int,
                 temperature: float = 1.0):
        super().__init__()
        self.n_clusters = n_clusters
        self.n_features = n_features
        self.temperature = temperature

        self.cluster_centers = nn.Parameter(torch.empty(n_clusters, n_features))
        self.reset_cluster_centers()

        # Lazily built graph structures (not part of the state dict); filled
        # by `_prepare_graph` on the first forward pass.
        self.register_buffer('adj_sparse', None, persistent=False)
        self.register_buffer('deg', None, persistent=False)
        # The edge_index the cached structures were built from; used to detect
        # that a different graph is being passed in.
        self._cached_edge_index = None

    def reset_cluster_centers(self, cluster_centers: Tensor = None) -> None:
        r"""Manually sets the cluster centers.

        Args:
            cluster_centers (torch.Tensor, optional): Tensor of shape
                :obj:`(n_clusters, n_features)` to initialize the cluster
                centers. If None, use Xavier uniform initialization.

        Raises:
            ValueError: If :obj:`cluster_centers` has the wrong shape.
        """
        if cluster_centers is None:
            nn.init.xavier_uniform_(self.cluster_centers)
            return

        expected = (self.n_clusters, self.n_features)
        got = tuple(cluster_centers.shape)
        # An explicit check (not `assert`) so it survives `python -O`; without
        # it, copy_ would silently broadcast e.g. a (1, F) tensor.
        if got != expected:
            raise ValueError(
                f"cluster_centers must have shape {expected}, got {got}")
        with torch.no_grad():
            self.cluster_centers.copy_(cluster_centers)

    def _is_graph_cached(self, edge_index: Tensor, N: int) -> bool:
        """Return True if the cached structures were built for this graph."""
        cached = getattr(self, '_cached_edge_index', None)
        if self.adj_sparse is None or cached is None:
            return False
        if self.adj_sparse.size(0) != N:
            return False
        # Fast path: the very same tensor object as last time. Note that an
        # in-place modification of that tensor is not detected.
        if cached is edge_index:
            return True
        return (cached.shape == edge_index.shape
                and cached.dtype == edge_index.dtype
                and cached.device == edge_index.device
                and torch.equal(cached, edge_index))

    def _prepare_graph(self, edge_index: Tensor, N: int) -> None:
        """Pre-computes and caches sparse adjacency and degree statistics.

        The cache is reused only when it was built for the same number of
        nodes *and* an identical ``edge_index``; any other graph triggers a
        rebuild, so a new graph with the same node count never reuses stale
        structures.
        """
        if self._is_graph_cached(edge_index, N):
            return

        # 1. Out-degree of every node (number of edges per source index).
        deg = degree(edge_index[0], N, dtype=torch.float)

        # 2. Sparse COO adjacency; coalesce() merges duplicate entries (by
        #    summation) and sorts indices for an efficient sparse-dense matmul.
        val = torch.ones(edge_index.size(1), device=edge_index.device)
        adj = torch.sparse_coo_tensor(edge_index, val, (N, N)).coalesce()

        self.adj_sparse = adj
        self.deg = deg
        self._cached_edge_index = edge_index

    def forward(self, z: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Compute MinCut and Orthogonality losses given node embeddings and
        graph structure.

        Args:
            z (torch.Tensor): Node embeddings :math:`(N, F)`.
            edge_index (torch.Tensor): Edge indices :math:`(2, E)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Scalar ``mincut_loss`` and
            scalar ``ortho_loss``.
        """
        N = z.size(0)
        K = self.n_clusters

        # === Step 0: build (or reuse) the cached graph structures ===
        self._prepare_graph(edge_index, N)

        # === Step 1: soft assignment S (N, K) ===
        S = torch.softmax(
            torch.matmul(z, self.cluster_centers.t()) / self.temperature,
            dim=-1)

        # The cache is always built as float32 on edge_index's device; match
        # S so sparse.mm works after e.g. `.double()` (no-op when matching).
        adj = self.adj_sparse.to(device=S.device, dtype=S.dtype)
        deg = self.deg.to(device=S.device, dtype=S.dtype)

        # === Step 2: MinCut loss ===
        # Tr(S^T A S) = sum(S * (A @ S)), computed with a sparse-dense matmul.
        tr_SAS = torch.sum(S * torch.sparse.mm(adj, S))
        # Tr(S^T D S) with diagonal D = sum_i d_i * ||S_i||^2.
        tr_SDS = torch.sum(deg.view(-1, 1) * (S * S))
        mincut_loss = -tr_SAS / (tr_SDS + EPS)

        # === Step 3: Orthogonality loss (K x K, cheap) ===
        # Frobenius norm (not squared) of the difference, as documented above.
        StS = torch.matmul(S.t(), S)
        StS_normalized = StS / (torch.norm(StS, p='fro') + EPS)
        I_target = torch.eye(K, device=S.device, dtype=S.dtype) / (K ** 0.5)
        ortho_loss = torch.norm(StS_normalized - I_target, p='fro')

        return mincut_loss, ortho_loss

    @torch.no_grad()
    def cluster(self, z: Tensor, soft: bool = False) -> Tensor:
        r"""Predict cluster assignments.

        Args:
            z (torch.Tensor): Node embeddings of shape (N, F).
            soft (bool, optional): If True, return soft assignment
                probabilities.

        Returns:
            torch.Tensor: Hard cluster indices or soft assignment matrix.
        """
        sim = torch.matmul(z, self.cluster_centers.t())
        if soft:
            return torch.softmax(sim / self.temperature, dim=-1)
        # For a positive temperature, the argmax of the raw similarities
        # equals the argmax of the soft assignment.
        return sim.argmax(dim=-1)