Source code for pyagc.clusters.inc_cluster_head

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from pyagc.clusters import BaseClusterHead


class INCClusterHead(BaseClusterHead):
    r"""
    Neural Clustering Layer proposed in the
    `"XAI Beyond Classification: Interpretable Neural Clustering"
    <https://jmlr.org/papers/v23/19-497.html>`_ paper (Peng et al., JMLR 2022).

    This layer implements a differentiable k-means reformulation, where
    cluster centers are learnable parameters. The cluster centers
    :math:`\mathbf{\Omega}_j` and weights :math:`\mathbf{W}_j` are related via:

    .. math::
        \mathbf{W}_j = 2 \alpha \mathbf{\Omega}_j

    Both the inputs :math:`\mathbf{z}_i` and the centers are L2-normalized
    before use. The loss, averaged over the :math:`N` samples, is:

    .. math::
        L = \frac{1}{\alpha N} \sum_{i=1}^{N}
        (2\alpha - \mathbf{W}_{c(i)}^T \mathbf{z}_i)

    where :math:`c(i) = \arg\max_j \mathbf{W}_j^T \mathbf{z}_i` is the
    nearest cluster to sample :math:`i`.

    Args:
        n_clusters (int): Number of clusters.
        n_features (int): Feature dimension of the input.
        alpha (float): Scaling factor for cluster centers. Must be strictly
            positive. Default is 0.001.

    Raises:
        ValueError: If :obj:`alpha` is not strictly positive.
    """

    def __init__(self, n_clusters: int, n_features: int, alpha: float = 0.001):
        super().__init__()

        # alpha is a divisor in ``forward`` and ``compute_cluster_center``,
        # and a negative value would flip the argmax towards the *least*
        # similar center. ``not alpha > 0`` also rejects NaN.
        if not alpha > 0:
            raise ValueError(f"alpha must be strictly positive, got {alpha!r}")

        self.n_clusters = n_clusters
        self.n_features = n_features
        self.alpha = alpha

        # Cluster centers (normalized to unit norm)
        self.cluster_centers = nn.Parameter(torch.empty(n_clusters, n_features))
        self.reset_cluster_centers()

    def reset_cluster_centers(self, cluster_centers: Optional[Tensor] = None) -> None:
        r"""
        Manually sets the cluster centers.

        The stored centers are always L2-normalized along the feature
        dimension, whether they are supplied or randomly initialized.

        Args:
            cluster_centers (torch.Tensor, optional): Tensor of shape
                (n_clusters, n_features) to initialize the cluster centers.
                If None, use Xavier uniform initialization.

        Raises:
            ValueError: If :obj:`cluster_centers` does not have shape
                (n_clusters, n_features).
        """
        if cluster_centers is None:
            nn.init.xavier_uniform_(self.cluster_centers)
            self.normalize_cluster_centers()
            return

        expected_shape = (self.n_clusters, self.n_features)
        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O`` and ``copy_`` would then silently broadcast a
        # wrongly shaped tensor into the parameter.
        if tuple(cluster_centers.shape) != expected_shape:
            raise ValueError(
                f"cluster_centers must have shape {expected_shape}, "
                f"got {tuple(cluster_centers.shape)}"
            )

        with torch.no_grad():
            self.cluster_centers.copy_(F.normalize(cluster_centers, dim=-1))

    def normalize_cluster_centers(self) -> None:
        r"""
        Rescale every cluster center to unit L2 norm, in place.

        No alpha scaling happens here; the factor :math:`2\alpha` is applied
        when the weights are built in :meth:`_get_weight`.
        """
        with torch.no_grad():
            normalized = F.normalize(self.cluster_centers, dim=-1)
            self.cluster_centers.copy_(normalized)

    def compute_cluster_center(self) -> Tensor:
        r"""
        Recover the cluster centers from the weights via
        :math:`\mathbf{\Omega}_j = \mathbf{W}_j / (2\alpha)`.

        Because the weights are :math:`2\alpha` times the row-normalized
        parameter, the result is the unit-norm cluster centers (up to
        floating-point rounding). The call is not wrapped in
        :obj:`torch.no_grad`, so the result stays attached to the autograd
        graph.

        Returns:
            Tensor of shape (n_clusters, n_features) with the cluster centers.
        """
        return (0.5 / self.alpha) * self._get_weight()

    def _get_weight(self) -> Tensor:
        r"""
        Compute the cluster weights
        :math:`\mathbf{W}_j = 2\alpha \mathbf{\Omega}_j`, where
        :math:`\mathbf{\Omega}_j` is the L2-normalized :math:`j`-th row of
        :obj:`cluster_centers`.

        Returns:
            Weight tensor of shape (n_clusters, n_features).
        """
        return 2.0 * self.alpha * F.normalize(self.cluster_centers, dim=-1)

    def forward(self, z: Tensor) -> Tensor:
        r"""
        Compute the clustering loss.

        Each sample is L2-normalized and scored against every weight vector;
        only the best-matching cluster of each sample contributes to the loss.

        Args:
            z (torch.Tensor): Input tensor of shape
                :obj:`(n_samples, n_features)`.

        Returns:
            Scalar loss tensor.
        """
        z = F.normalize(z, dim=-1)
        similarity = F.linear(z, self._get_weight())

        # ``max(dim)`` returns the value at the first maximal index, i.e. the
        # same entry (and the same gradient routing) as argmax followed by
        # gather, without assuming the input is exactly 2-D.
        best_similarity = similarity.max(dim=-1).values

        return (1.0 / self.alpha) * (2.0 * self.alpha - best_similarity).mean()

    @torch.no_grad()
    def cluster(self, z: Tensor, soft: bool = False) -> Tensor:
        r"""
        Predicts cluster assignments.

        Args:
            z (torch.Tensor): Input tensor of shape
                :obj:`(n_samples, n_features)`.
            soft (bool, optional): If True, returns the soft assignment
                matrix; if False, returns hard cluster assignments.
                (default: :obj:`False`)

        Returns:
            - If :obj:`soft` is False, :obj:`(n_samples,)` tensor of cluster
              indices.
            - If :obj:`soft` is True, :obj:`(n_samples, n_clusters)` tensor
              of probabilities.
        """
        z = F.normalize(z, dim=-1)
        similarity = F.linear(z, self._get_weight())

        if soft:
            # NOTE: the logits are cosine similarities scaled by 2 * alpha,
            # so they lie in [-2 * alpha, 2 * alpha]; with a small alpha the
            # resulting distribution is close to uniform.
            return similarity.softmax(dim=-1)
        return similarity.argmax(dim=-1)

    def gradient_normalize(self) -> None:
        r"""
        Normalize the gradient of the cluster centers.

        Each row of :obj:`cluster_centers.grad` is rescaled in place to L2
        norm :obj:`0.2 * alpha` (all-zero rows stay zero). Does nothing when
        no gradient is stored. Presumably meant to be called between
        ``backward()`` and the optimizer step -- confirm against the trainer.
        """
        grad = self.cluster_centers.grad
        if grad is None:
            return

        with torch.no_grad():
            normed_grad = F.normalize(grad, dim=-1)
            grad.copy_(normed_grad * (0.2 * self.alpha))