# Source code for pyagc.clusters.torch_kmeans

import tqdm
import torch
from torch import Tensor
import torch.nn.functional as F
import torch.distributed as dist
import warnings as _warnings
from typing import Callable, Optional, Tuple, Union


def _distributed_sync(tensor: Tensor) -> Tensor:
    r"""Synchronizes tensors across all distributed workers."""
    tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    return torch.stack(tensors_gather)


def _pairwise_cosine(x1: Tensor, x2: Tensor, pairwise: bool = True) -> Tensor:
    r"""Computes pairwise Cosine distances."""
    x1 = F.normalize(x1)
    x2 = F.normalize(x2)
    if not pairwise:
        return 1 - (x1 * x2).sum(dim=-1)
    return 1 - x1.mm(x2.T)

def _pairwise_dot(x1: Tensor, x2: Tensor, pairwise: bool = True) -> Tensor:
    r"""Computes pairwise Dot distances."""
    if not pairwise:
        return - (x1 * x2).sum(dim=-1)
    return - x1.mm(x2.T)

def _pairwise_euclidean(x1: Tensor, x2: Tensor, pairwise: bool = True) -> Tensor:
    r"""Computes pairwise Euclidean distances."""
    if not pairwise:
        return (x1 - x2).pow(2).sum(dim=-1).sqrt()
    return torch.cdist(x1, x2, p=2.)


def _stable_cumsum(arr: Tensor, dim: Optional[int] = None, rtol=1e-05, atol=1e-08) -> Tensor:
    r"""Performs a numerically stable cumulative sum."""
    if dim is None:
        arr = arr.flatten()
        dim = 0
    out = torch.cumsum(arr, dim=dim, dtype=torch.float64)
    expected = torch.sum(arr, dim=dim, dtype=torch.float64)
    if not torch.all(torch.isclose(out[-1], expected, rtol=rtol, atol=atol, equal_nan=True)):
        _warnings.warn('cumsum was found to be unstable: its last element does not correspond to sum',
                       RuntimeWarning)
    return out


def _kmeans_plusplus(X: Tensor, n_clusters: int, random_state: int, pairwise_distance: Callable,
                     n_local_trials: Optional[int] = None) -> Tuple[Tensor, Tensor]:
    r"""Computational component for k-means++ initialization.

    Greedy k-means++ seeding:
    - The first center is a uniformly random row of ``X``.
    - Each subsequent center is chosen among ``n_local_trials`` candidate rows.
      Candidates are sampled with probability proportional to their current
      distance to the nearest already-chosen center.
    - The chosen candidate is the one that minimizes the summed distance
      ("potential") over all samples.

    Args:
        X (torch.Tensor): Data of shape :obj:`(n_samples, n_features)`.
        n_clusters (int): Number of centers to select.
        random_state (int): Seed for a local :class:`torch.Generator` created on
            ``X.device``. The global RNG is not touched.
        pairwise_distance (Callable): Called as ``pairwise_distance(A, B)``.
            Presumably returns a :obj:`(len(A), len(B))` distance matrix (TODO
            confirm for custom callables). Its values are used directly as
            sampling weights.
            NOTE(review): with ``_pairwise_euclidean`` these are plain, not
            squared, distances, despite the ``closest_dist_sq`` name below.
            Sampling also relies on a non-decreasing cumulative sum
            (``searchsorted``), so the distances are assumed to be non-negative.
        n_local_trials (int, optional): Candidates evaluated per new center.
            Defaults to ``2 + int(log(n_clusters))``.

    Returns:
        A tuple ``(centers, indices)``:
        - ``centers`` has shape :obj:`(n_clusters, n_features)`.
        - ``indices`` has shape :obj:`(n_clusters,)` and dtype ``torch.int``. It
          holds the row of ``X`` each center was taken from.
    """
    n_samples, n_features = X.size()

    generator = torch.Generator(device=str(X.device)).manual_seed(random_state)
    centers = torch.empty((n_clusters, n_features), dtype=X.dtype, device=X.device)

    # Set the number of local seeding trials if none is given
    if n_local_trials is None:
        # This is what Arthur/Vassilvitskii tried, but did not report
        # specific results for other than mentioning in the conclusion
        # that it helped.
        n_local_trials = 2 + int(torch.log(torch.tensor(n_clusters)).item())

    # Pick the first center uniformly at random; center_id has shape (1,).
    center_id = torch.randint(n_samples, (1,), generator=generator, device=X.device)

    indices = torch.full((n_clusters,), -1, dtype=torch.int, device=X.device)
    centers[0] = X[center_id]
    indices[0] = center_id

    # Distances from every sample to the first center, and their total ("potential").
    # Shape is (1, n_samples) here, assuming pairwise_distance returns (len(A), len(B)).
    closest_dist_sq = pairwise_distance(centers[0, None], X)
    current_pot = closest_dist_sq.sum()

    # Pick the remaining n_clusters-1 points
    for c in range(1, n_clusters):
        # Sample candidate rows with probability proportional to closest_dist_sq.
        # Uniform draws in [0, current_pot) are mapped through the (flattened)
        # cumulative sum, i.e. inverse-CDF sampling.
        rand_vals = torch.rand(n_local_trials, generator=generator, device=X.device) * current_pot
        candidate_ids = torch.searchsorted(_stable_cumsum(closest_dist_sq), rand_vals)
        # XXX: numerical imprecision can result in a candidate_id out of range
        candidate_ids.clamp_(max=closest_dist_sq.numel() - 1)

        # Distances from each candidate to all samples: (n_local_trials, n_samples).
        distance_to_candidates = pairwise_distance(X[candidate_ids], X)

        # For each candidate, the closest-center distance if it were added (written
        # in place into distance_to_candidates), and the resulting potential.
        torch.minimum(closest_dist_sq, distance_to_candidates, out=distance_to_candidates)
        candidates_pot = distance_to_candidates.sum(dim=-1)

        # Keep the candidate with the lowest potential. best_candidate is first a
        # position within candidate_ids, then it is rebound to the row index in X.
        best_candidate = torch.argmin(candidates_pot)
        current_pot = candidates_pot[best_candidate]
        closest_dist_sq = distance_to_candidates[best_candidate]
        best_candidate = candidate_ids[best_candidate]

        # Permanently add best center candidate found in local tries
        centers[c] = X[best_candidate]
        indices[c] = best_candidate

    return centers, indices


class TorchKMeans:
    r"""A PyTorch-based KMeans clustering implementation supporting both Euclidean
    and Cosine distance metrics, with optional distributed training.

    This implementation is adapted from:
    `Hzzone/torch_clustering <https://github.com/Hzzone/torch_clustering>`_.

    Args:
        metric (str, optional): Distance metric to use: ``'euclidean'`` or
            ``'cosine'`` (case-insensitive). (default: ``'euclidean'``)
        init (str or torch.Tensor, optional): Method for initialization:
            ``'k-means++'``, ``'random'`` or user-specified tensor of shape
            :obj:`(n_clusters, n_features)`. Passing a tensor forces
            ``n_init=1``. (default: ``'k-means++'``)
        random_state (int, optional): Random seed for initialization. ``None``
            is treated as ``0``. (default: ``None``)
        n_clusters (int, optional): Number of clusters. (default: ``8``)
        n_init (int, optional): Number of times the algorithm will be run with
            different centroid seeds. (default: ``10``)
        max_iter (int, optional): Maximum number of iterations of the k-means
            algorithm for a single run. (default: ``300``)
        tol (float, optional): Relative convergence tolerance. It is multiplied
            by the mean per-feature variance of the data and compared against
            the total center shift between two consecutive iterations.
            (default: ``1e-4``)
        distributed (bool, optional): Whether to use distributed training. Only
            takes effect when :mod:`torch.distributed` is initialized.
            (default: ``False``)
        verbose (bool, optional): Whether to print progress information (root
            worker only). (default: ``False``)

    Note:
        Inertia is the sum of the metric's distances. For ``'euclidean'`` these
        are plain, not squared, distances. A cluster that becomes empty during
        an iteration gets an all-zero center.
    """

    @torch.no_grad()
    def __init__(self, metric: str = 'euclidean', init: Union[str, Tensor] = 'k-means++',
                 random_state: Optional[int] = None, n_clusters: int = 8, n_init: int = 10,
                 max_iter: int = 300, tol: float = 1e-4, distributed: bool = False,
                 verbose: bool = False):
        self.metric = metric.lower()
        # Validate and dispatch on the *normalized* name so that e.g. 'Euclidean'
        # is accepted, as the .lower() call intends.
        if self.metric not in {'euclidean', 'cosine'}:
            raise ValueError(
                'Invalid metric value. Must be either "euclidean" or "cosine".'
                ' But got "{}".'.format(metric)
            )
        self.distance_metric = {'euclidean': _pairwise_euclidean,
                                'cosine': _pairwise_cosine}[self.metric]
        self.n_clusters = n_clusters
        self.n_init = n_init
        self.max_iter = max_iter
        self.tol = tol
        self.cluster_centers_: Optional[Tensor] = None
        self.init = init
        # A user-supplied initial state is deterministic, so repeated runs are pointless.
        if isinstance(self.init, torch.Tensor):
            self.n_init = 1
        if random_state is None:
            random_state = 0
        self.random_state = random_state
        self.is_root_worker = not dist.is_initialized() or dist.get_rank() == 0
        self.verbose = verbose and self.is_root_worker
        self.distributed = distributed and dist.is_initialized()
        if self.verbose and self.distributed:
            print('Perform K-means in distributed mode.')
        self.world_size = dist.get_world_size() if self.distributed else 1
        self.rank = dist.get_rank() if self.distributed else 0

    @torch.no_grad()
    def initialize(self, X: Tensor, random_state: int) -> Tensor:
        r"""Initializes the cluster centers.

        Args:
            X (torch.Tensor): The input data of shape :obj:`(n_samples, n_features)`.
            random_state (int): The random seed.

        Returns:
            Initialized cluster centers of shape :obj:`(n_clusters, n_features)`.

        Raises:
            NotImplementedError: If ``init`` is an unknown string or neither a
                string nor a tensor.
        """
        num_samples = X.size(0)
        if isinstance(self.init, str):
            generator = torch.Generator().manual_seed(random_state)
            if self.init == 'random':
                # n_clusters distinct rows of X, chosen uniformly at random.
                indices = torch.randperm(num_samples, generator=generator)[:self.n_clusters]
                init_state = X[indices]
            elif self.init == 'k-means++':
                init_state, _ = _kmeans_plusplus(X, n_clusters=self.n_clusters,
                                                 random_state=random_state,
                                                 pairwise_distance=self.distance_metric)
            else:
                raise NotImplementedError(f"Unknown init method: {self.init}")
        elif isinstance(self.init, Tensor):
            init_state = self.init.to(X)
        else:
            raise NotImplementedError
        return init_state

    @torch.no_grad()
    def fit_predict(self, X: Tensor) -> Tensor:
        r"""Performs k-means clustering on the input data and returns cluster labels.

        Sets ``self.cluster_centers_`` to the best centers found. Also sets
        ``self.stats``, which holds the final state, inertia and labels of every run.

        Args:
            X (torch.Tensor): The input data of shape :obj:`(n_samples, n_features)`.

        Returns:
            Cluster assignments of shape :obj:`(n_samples,)`.
        """
        # Convergence threshold on the total center shift, scaled by the data spread.
        tol = torch.mean(torch.var(X, dim=0)) * self.tol
        min_inertia, best_states, best_labels = float('inf'), None, None

        # Each worker takes a disjoint slice of seeds, so that a distributed job
        # tries n_init * world_size different initializations in total.
        random_states = torch.arange(self.n_init * self.world_size) + self.random_state
        random_states = random_states[self.rank::self.world_size]
        self.stats = {'state': [], 'inertia': [], 'label': []}

        for n_init in range(self.n_init):
            random_state = int(random_states[n_init])
            old_state = self.initialize(X, random_state=random_state)
            old_labels, inertia = self._predict(X, old_state)
            # With max_iter == 0 the run's result is the initialization itself.
            state, labels = old_state, old_labels

            progress_bar = tqdm.tqdm(total=self.max_iter, disable=not self.verbose)
            for n_iter in range(self.max_iter):
                # Update step: new centers are the per-cluster means of X. The 1e-6
                # offset avoids division by zero, so an empty cluster gets an
                # all-zero center.
                state = torch.zeros_like(old_state)
                counts = torch.zeros(self.n_clusters, dtype=X.dtype, device=X.device) + 1e-6
                counts.index_add_(0, labels, torch.ones_like(labels, dtype=X.dtype))
                state.index_add_(0, labels, X)
                state = state / counts.view(-1, 1)

                # Assignment step: relabel the samples and recompute inertia.
                labels, inertia = self._predict(X, state)

                if inertia < min_inertia:
                    min_inertia = inertia
                    best_states, best_labels = state, labels

                if self.verbose:
                    progress_bar.set_description(
                        f'nredo {n_init + 1}/{self.n_init:02d}, iteration {n_iter:03d} with inertia {inertia:.2f}')
                    progress_bar.update(1)

                if torch.equal(labels, old_labels):
                    if self.verbose:
                        print(f"Converged at iteration {n_iter}: strict convergence.")
                    break
                center_shift_tot = self.distance_metric(old_state, state, pairwise=False).sum()
                if center_shift_tot <= tol:
                    if self.verbose:
                        print(f"Converged at iteration {n_iter}: center shift "
                              f"{center_shift_tot:.2e} within tolerance {tol:.2e}.")
                    break

                old_labels = labels
                old_state = state
            progress_bar.close()

            # Record the last evaluated (state, labels, inertia) triple, so that the
            # three entries always describe the same clustering.
            self.stats['state'].append(state)
            self.stats['inertia'].append(inertia)
            self.stats['label'].append(labels)

        self.stats['state'] = torch.stack(self.stats['state'])
        self.stats['inertia'] = torch.tensor(self.stats['inertia'])
        self.stats['label'] = torch.stack(self.stats['label'])

        # No iteration beat inf (e.g. max_iter == 0 or NaN inertia): fall back to
        # the last run's clustering instead of returning None.
        if best_labels is None:
            min_inertia, best_states, best_labels = inertia, state, labels

        final_inertia = min_inertia
        if self.distributed:
            # Gather every worker's best inertia and adopt the winning worker's
            # result. The gathered tensors live on X.device, the same device as the
            # broadcast label/center tensors, so CUDA-only backends (e.g. NCCL) work.
            all_inertia = _distributed_sync(torch.tensor(min_inertia, device=X.device))
            best_idx = torch.argmin(all_inertia).item()
            dist.broadcast(best_labels, src=best_idx)
            dist.broadcast(best_states, src=best_idx)
            final_inertia = all_inertia[best_idx].item()
            self.stats['state'] = _distributed_sync(self.stats['state'])
            self.stats['inertia'] = _distributed_sync(self.stats['inertia'].to(X.device)).cpu()
            self.stats['label'] = _distributed_sync(self.stats['label'])

        if self.verbose:
            print(f"Final min inertia {final_inertia:.2f}.")

        self.cluster_centers_ = best_states
        return best_labels

    @torch.no_grad()
    def _predict(self, X: Tensor, cluster_centers_: Optional[Tensor] = None) -> Tuple[Tensor, float]:
        r"""Assigns each sample in :obj:`X` to the nearest cluster center.

        Args:
            X (torch.Tensor): Input data of shape :obj:`(n_samples, n_features)`.
            cluster_centers_ (torch.Tensor, optional): Precomputed cluster centers.
                If :obj:`None`, uses :obj:`self.cluster_centers_`.

        Returns:
            1. Cluster labels of shape :obj:`(n_samples,)`.
            2. Total inertia, i.e. the sum of distances to the assigned centers
               (float scalar).
        """
        if cluster_centers_ is None:
            cluster_centers_ = self.cluster_centers_
        dist_mat = self.distance_metric(X, cluster_centers_)  # (n_samples, n_clusters)
        dists, labels = dist_mat.min(dim=1)
        return labels, dists.sum().item()

    @torch.no_grad()
    def predict(self, X: Tensor, soft: bool = False) -> Tensor:
        r"""Assigns samples to clusters based on the fitted cluster centers.

        Distances are computed with the configured metric: plain (non-squared)
        Euclidean distance or cosine distance.

        Args:
            X (torch.Tensor): Input tensor of shape :obj:`(n_samples, n_features)`.
            soft (bool, optional): If True, returns the soft assignment matrix,
                i.e. a softmax over ``-sqrt(distance)``. If False, returns hard
                cluster assignments. (default: :obj:`False`)

        Returns:
            - If :obj:`soft` is False, :obj:`(n_samples,)` tensor of cluster indices.
            - If :obj:`soft` is True, :obj:`(n_samples, n_clusters)` tensor of probabilities.

        Raises:
            RuntimeError: If no cluster centers have been fitted yet.
        """
        if self.cluster_centers_ is None or self.cluster_centers_.numel() == 0:
            raise RuntimeError("Must call `fit_predict` before using `predict`.")
        dists = self.distance_metric(X, self.cluster_centers_)  # (n_samples, n_clusters)
        if soft:
            return (-dists.sqrt()).softmax(dim=-1)  # smaller distance => higher score
        return dists.argmin(dim=-1)  # assign to nearest cluster center

    def __repr__(self) -> str:
        """String representation of the TorchKMeans object."""
        return (
            f"{type(self).__name__}(metric={self.metric!r}, "
            f"init={self.init!r}, "
            f"n_clusters={self.n_clusters}, "
            f"n_init={self.n_init}, "
            f"max_iter={self.max_iter}, "
            f"tol={self.tol}, "
            f"random_state={self.random_state}, "
            f"distributed={self.distributed}, "
            f"verbose={self.verbose})"
        )
if __name__ == '__main__':
    # Demo: cluster random Gaussian data with verbose progress output.
    data = torch.randn(1280, 16)
    model = TorchKMeans(
        metric='euclidean',
        init='k-means++',
        random_state=0,
        n_clusters=1000,
        n_init=10,
        max_iter=300,
        tol=1e-4,
        distributed=False,
        verbose=True,
    )
    model.fit_predict(data)