# Source code for pyagc.clusters.triton_kmeans

# Adapted from https://github.com/svg-project/flash-kmeans

from typing import Optional, Union, Tuple

import torch
import torch.nn.functional as F

try:
    import triton
    import triton.language as tl
    _has_triton = True
except ImportError:
    _has_triton = False
    triton = None
    tl = None


def _ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


# ============================================================
# All Triton kernels and related helpers are only defined
# when Triton is available.
# ============================================================
if _has_triton:
    import tqdm

    # Autotuning search space shared by the assignment kernels: the full
    # cross product of point-tile size (BLOCK_N), centroid-tile size (BLOCK_K),
    # warp count (num_warps) and num_stages listed below.
    _TUNE_CONFIGS = [
        triton.Config({"BLOCK_N": BN, "BLOCK_K": BK}, num_stages=num_stages, num_warps=wp)
        for BN in [32, 64, 128]
        for BK in [32, 64, 128]
        for wp in [4, 8]
        for num_stages in [1, 2, 4]
    ]


    def _cfg_keep(conf):
        """Basic heuristic to prune unbalanced configs."""
        BN = conf.kwargs["BLOCK_N"]
        BK = conf.kwargs["BLOCK_K"]
        if BN * BK < 32 * 32 and conf.num_warps > 4:
            return False
        return True


    # Drop the configurations rejected by _cfg_keep, preserving order.
    _TUNE_CONFIGS = [cfg for cfg in _TUNE_CONFIGS if _cfg_keep(cfg)]


    @triton.autotune(_TUNE_CONFIGS, key=["N", "K"])
    @triton.jit
    def _euclid_assign_kernel(
            x_ptr,  # *f16 / *f32 [N, D]
            c_ptr,  # *f16 / *f32 [K, D]
            x_sq_ptr,  # *f32         [N]
            c_sq_ptr,  # *f32         [K]
            out_ptr,  # *i32         [N]
            N: tl.constexpr,
            K: tl.constexpr,
            D: tl.constexpr,
            stride_x_n: tl.constexpr,
            stride_x_d: tl.constexpr,
            stride_c_k: tl.constexpr,
            stride_c_d: tl.constexpr,
            BLOCK_N: tl.constexpr,
            BLOCK_K: tl.constexpr,
    ):
        """Assign each point to its nearest centroid under squared Euclidean distance.

        Each program handles one tile of BLOCK_N points and scans all K
        centroids in chunks of BLOCK_K. Distances are formed as
        ``x_sq + c_sq - 2 * (x @ c)`` from the precomputed squared norms and
        clamped at zero. The index of the closest centroid is written to
        ``out_ptr`` for rows below N only.

        x_sq_ptr, c_sq_ptr and out_ptr are indexed with unit stride.

        NOTE(review): the feature axis is loaded in one piece with
        ``tl.arange(0, D)`` and no mask along D, so D presumably has to be a
        size that tl.arange / tl.dot accept (power of two, large enough for
        tl.dot) -- confirm against the callers.
        """
        pid_n = tl.program_id(0)

        # Row indices covered by this program; rows >= N are masked out.
        n_start = pid_n * BLOCK_N
        n_offsets = n_start + tl.arange(0, BLOCK_N)
        n_mask = n_offsets < N

        # Load x tile  (BLOCK_N, D); masked rows read as 0.0
        offs_d = tl.arange(0, D)
        x_ptrs = x_ptr + n_offsets[:, None] * stride_x_n + offs_d[None, :] * stride_x_d
        x_tile = tl.load(x_ptrs, mask=n_mask[:, None], other=0.0)
        x_tile = x_tile.to(tl.float32)

        # Pre-load x_sq for the tile  (BLOCK_N,)
        xsq_ptrs = x_sq_ptr + n_offsets
        x_sq_tile = tl.load(xsq_ptrs, mask=n_mask, other=0.0).to(tl.float32)

        # Init best distance / index (3.4e38 acts as "+infinity" for float32)
        best_dist = tl.full((BLOCK_N,), 3.4e38, tl.float32)
        best_idx = tl.zeros((BLOCK_N,), tl.int32)

        # Iterate over centroids in chunks of BLOCK_K
        for k_start in range(0, K, BLOCK_K):
            k_offsets = k_start + tl.arange(0, BLOCK_K)
            k_mask = k_offsets < K

            # Load centroid tile transposed, i.e. as (D, BLOCK_K), so it can be
            # used directly as the right-hand operand of tl.dot
            c_ptrs = c_ptr + k_offsets[None, :] * stride_c_k + offs_d[:, None] * stride_c_d
            c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0)
            c_tile = c_tile.to(tl.float32)

            # load c_sq for the tile  (BLOCK_K,)
            csq_ptrs = c_sq_ptr + k_offsets
            cent_sq = tl.load(csq_ptrs, mask=k_mask, other=0.0).to(tl.float32)

            # Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile
            cross = tl.dot(x_tile, c_tile).to(tl.float32)

            # Squared Euclidean distance; clamp small negatives caused by
            # floating-point cancellation
            dist = x_sq_tile[:, None] + cent_sq[None, :] - 2.0 * cross
            dist = tl.maximum(dist, 0.0)

            # Mask out invalid centroid columns so they can never be the minimum
            dist = tl.where(k_mask[None, :], dist, 3.4e38)

            curr_min = tl.min(dist, axis=1)
            curr_idx = tl.argmin(dist, axis=1)

            # Strict '<' keeps the earlier chunk's winner when distances tie
            update = curr_min < best_dist
            best_dist = tl.where(update, curr_min, best_dist)
            best_idx = tl.where(update, k_start + curr_idx, best_idx)

        # Write results
        out_ptrs = out_ptr + n_offsets
        tl.store(out_ptrs, best_idx, mask=n_mask)


    @triton.autotune(_TUNE_CONFIGS, key=["N", "K"])
    @triton.jit
    def _cosine_assign_kernel(
            x_ptr,  # *f16 / *f32 [N, D]
            c_ptr,  # *f16 / *f32 [K, D]
            out_ptr,  # *i32         [N]
            N: tl.constexpr,
            K: tl.constexpr,
            D: tl.constexpr,
            stride_x_n: tl.constexpr,
            stride_x_d: tl.constexpr,
            stride_c_k: tl.constexpr,
            stride_c_d: tl.constexpr,
            BLOCK_N: tl.constexpr,
            BLOCK_K: tl.constexpr,
    ):
        """Assign each point to the centroid with the largest dot product.

        Each program handles one tile of BLOCK_N points and scans all K
        centroids in chunks of BLOCK_K, keeping the running maximum of
        ``x @ c`` and its centroid index. The kernel itself does not
        normalize anything; the score is a cosine similarity only when the
        caller passes L2-normalized rows. Results are written to ``out_ptr``
        (unit stride) for rows below N only.

        NOTE(review): the feature axis is loaded in one piece with
        ``tl.arange(0, D)`` and no mask along D, so D presumably has to be a
        size that tl.arange / tl.dot accept -- confirm against the callers.
        """
        pid_n = tl.program_id(0)

        n_start = pid_n * BLOCK_N
        n_offsets = n_start + tl.arange(0, BLOCK_N)
        n_mask = n_offsets < N

        # Load x tile  (BLOCK_N, D); masked rows read as 0.0
        offs_d = tl.arange(0, D)
        x_ptrs = x_ptr + n_offsets[:, None] * stride_x_n + offs_d[None, :] * stride_x_d
        x_tile = tl.load(x_ptrs, mask=n_mask[:, None], other=0.0)
        x_tile = x_tile.to(tl.float32)

        # Init best similarity / index (-3.4e38 acts as "-infinity" for float32)
        best_dist = tl.full((BLOCK_N,), -3.4e38, tl.float32)
        best_idx = tl.zeros((BLOCK_N,), tl.int32)

        # Iterate over centroids in chunks of BLOCK_K
        for k_start in range(0, K, BLOCK_K):
            k_offsets = k_start + tl.arange(0, BLOCK_K)
            k_mask = k_offsets < K

            # Load centroid tile transposed, i.e. as (D, BLOCK_K)
            c_ptrs = c_ptr + k_offsets[None, :] * stride_c_k + offs_d[:, None] * stride_c_d
            c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0)
            c_tile = c_tile.to(tl.float32)

            # Similarity scores (BLOCK_N, BLOCK_K) = x_tile @ c_tile
            cross = tl.dot(x_tile, c_tile).to(tl.float32)

            # Mask out padded centroid columns with the most negative value.
            # Filling them with 0.0 would let a padded column win the argmax
            # whenever every real similarity is negative, producing a
            # centroid index >= K.
            dist = tl.where(k_mask[None, :], cross, -3.4e38)

            curr_max = tl.max(dist, axis=1)
            curr_idx = tl.argmax(dist, axis=1)

            # Strict '>' keeps the earlier chunk's winner on ties and never
            # selects a padded column (its score equals the initial value).
            update = curr_max > best_dist
            best_dist = tl.where(update, curr_max, best_dist)
            best_idx = tl.where(update, k_start + curr_idx, best_idx)

        # Write results
        out_ptrs = out_ptr + n_offsets
        tl.store(out_ptrs, best_idx, mask=n_mask)


    def euclid_assign_triton(x: torch.Tensor, centroids: torch.Tensor, x_sq: torch.Tensor = None,
                             out: torch.Tensor = None, c_sq: torch.Tensor = None) -> torch.Tensor:
        """Return nearest-centroid indices using Triton kernel.

        Args:
            x         : (N, D) float16 / float32 (on CUDA)
            centroids : (K, D) same dtype/device as x
            x_sq      : (N,)   float32 - ||x||^2 per point (optional, computed if omitted)
            out       : (N,)   int32   - pre-allocated output tensor (optional)
            c_sq      : (K,)   float32 - ||centroids||^2 per centroid (optional, computed if omitted)

        Returns:
            cluster_ids (N,) int32 (``out`` itself when it was supplied)

        Raises:
            AssertionError: on device, dtype, rank or shape mismatches.
        """
        assert x.is_cuda and centroids.is_cuda, "All tensors must be on CUDA"
        assert centroids.dtype == x.dtype, "centroids dtype mismatch"
        assert x.ndim == 2 and centroids.ndim == 2, "Expected 2D tensors"

        N, D = x.shape
        K, D2 = centroids.shape
        assert D == D2, "Feature dimension mismatch"

        if x_sq is None:
            x_sq = (x.to(torch.float32) ** 2).sum(-1)
        assert x_sq.shape == (N,), "x_sq shape mismatch"

        if out is None:
            out = torch.empty(N, device=x.device, dtype=torch.int32)
        # The kernel writes N entries through out; a shorter buffer would be
        # written out of bounds.
        assert out.shape == (N,), "out shape mismatch"

        if c_sq is None:
            c_sq = (centroids.to(torch.float32) ** 2).sum(-1)
        # Same reasoning as for x_sq: the kernel reads K entries from c_sq.
        assert c_sq.shape == (K,), "c_sq shape mismatch"

        stride_x_n, stride_x_d = x.stride()
        stride_c_k, stride_c_d = centroids.stride()

        # One program per BLOCK_N points; BLOCK_N is chosen by the autotuner.
        grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)

        _euclid_assign_kernel[grid](
            x, centroids, x_sq, c_sq, out,
            N, K, D,
            stride_x_n, stride_x_d,
            stride_c_k, stride_c_d,
        )
        return out


    def cosine_assign_triton(x: torch.Tensor, centroids: torch.Tensor,
                             out: torch.Tensor = None) -> torch.Tensor:
        """Pick, for every row of ``x``, the centroid with the highest dot product.

        Args:
            x         : (N, D) float16 / float32 (on CUDA)
            centroids : (K, D) same dtype/device as x
            out       : (N,)   int32 pre-allocated output tensor (optional)

        Returns:
            cluster_ids (N,) int32
        """
        assert x.is_cuda and centroids.is_cuda, "All tensors must be on CUDA"
        assert centroids.dtype == x.dtype, "centroids dtype mismatch"
        assert x.ndim == 2 and centroids.ndim == 2, "Expected 2D tensors"

        (n_points, n_feats), (n_centroids, centroid_feats) = x.shape, centroids.shape
        assert n_feats == centroid_feats, "Feature dimension mismatch"

        result = out
        if result is None:
            result = torch.empty(n_points, device=x.device, dtype=torch.int32)

        x_row_stride, x_col_stride = x.stride()
        c_row_stride, c_col_stride = centroids.stride()

        def launch_grid(meta):
            # One program per BLOCK_N points; BLOCK_N comes from the autotuner.
            return (triton.cdiv(n_points, meta["BLOCK_N"]),)

        _cosine_assign_kernel[launch_grid](
            x, centroids, result,
            n_points, n_centroids, n_feats,
            x_row_stride, x_col_stride,
            c_row_stride, c_col_stride,
        )
        return result


    @triton.jit
    def _centroid_update_kernel(
            x_ptr,  # *f16 / *f32 [N, D]
            cluster_ptr,  # *i32        [N]
            sum_ptr,  # *f32        [K, D]
            count_ptr,  # *i32        [K]
            # --- strides (elements) ---
            stride_x_n, stride_x_d,
            stride_sum_k, stride_sum_d,
            N: tl.constexpr,
            D: tl.constexpr,
            K: tl.constexpr,
            BLOCK_D: tl.constexpr,
    ):
        """Accumulate one token's features into its cluster's running sum.

        Each program handles exactly one token (row of x): it reads the token's
        cluster id, atomically adds the token's D features (in chunks of
        BLOCK_D) to that cluster's row of ``sum_ptr`` and atomically increments
        that cluster's entry of ``count_ptr`` by one. Both buffers are only
        added to, never reset, inside this kernel.
        """
        pid = tl.program_id(axis=0)
        token_idx = pid

        # Programs launched beyond the last token do nothing.
        if token_idx >= N:
            return

        # pointer to this token's feature vector
        x_offset = token_idx * stride_x_n
        x_tok_ptr = x_ptr + x_offset

        cluster_idx = tl.load(cluster_ptr + token_idx)
        # Ids >= K are redirected to cluster 0 instead of writing out of
        # bounds; negative ids are not guarded here.
        cluster_idx = tl.where(cluster_idx < K, cluster_idx, 0)

        # base ptr for centroid accum array
        centroid_base = cluster_idx * stride_sum_k

        offs = tl.arange(0, BLOCK_D)
        for d_start in range(0, D, BLOCK_D):
            # Mask the tail chunk when D is not a multiple of BLOCK_D.
            mask = offs + d_start < D
            feats = tl.load(x_tok_ptr + (d_start + offs) * stride_x_d, mask=mask, other=0.0)
            feats = feats.to(tl.float32)
            dest_ptr = sum_ptr + centroid_base + (d_start + offs) * stride_sum_d
            tl.atomic_add(dest_ptr, feats, mask=mask)

        # One token contributes one count to its cluster.
        tl.atomic_add(count_ptr + cluster_idx, 1)


    @triton.jit
    def _centroid_update_chunk_kernel(
            x_ptr,  # *f16 / *f32 [N, D] – ORIGINAL ORDER
            sorted_idx_ptr,  # *i32        [N]    – indices after sort
            sorted_cluster_ptr,  # *i32        [N]    – cluster ids in sorted order
            sum_ptr,  # *f32        [K, D]
            count_ptr,  # *i32        [K]
            # strides
            stride_x_n, stride_x_d,
            N: tl.constexpr,
            D: tl.constexpr,
            K: tl.constexpr,
            BLOCK_N: tl.constexpr,
    ):
        """Accumulate per-cluster feature sums and counts for one chunk of sorted tokens.

        Each program handles BLOCK_N consecutive positions of the sorted order.
        Because the cluster ids are sorted, only the ids between the chunk's
        first and last entry can occur in it; for each such id the program
        reduces the matching rows of x locally and issues one atomic add per
        feature plus one atomic add to the count.

        The destination row is addressed as ``sum_ptr + cid * D``, i.e. the
        sum buffer is assumed to have row stride D and unit column stride.
        K is accepted but not referenced in the body, so cluster ids are not
        range-checked here.

        NOTE(review): ``tl.arange(0, D)`` is used without a mask along D, so D
        presumably has to be a power of two -- confirm against the callers.
        """
        pid_chunk = tl.program_id(axis=0)
        chunk_start = pid_chunk * BLOCK_N

        # Programs launched beyond the last token do nothing.
        if chunk_start >= N:
            return

        # helper aranges
        offs_token = tl.arange(0, BLOCK_N)
        offs_dim = tl.arange(0, D)

        # token indices & validity mask
        token_idx = chunk_start + offs_token
        valid_tok = token_idx < N
        first_token_idx = chunk_start
        last_token_idx = tl.minimum(chunk_start + BLOCK_N, N) - 1

        # Load cluster ids; positions past N read as -1 so they match no cluster
        first_id = tl.load(sorted_cluster_ptr + first_token_idx)
        last_id = tl.load(sorted_cluster_ptr + last_token_idx)
        all_ids = tl.load(sorted_cluster_ptr + token_idx, mask=valid_tok, other=-1)

        # Load original indices (row numbers into x for each sorted position)
        all_tokens_idxs = tl.load(sorted_idx_ptr + token_idx, mask=valid_tok, other=-1)

        # Only ids in [first_id, last_id] can appear in a sorted chunk.
        for cid in range(first_id, last_id + 1):
            cluster_mask = all_ids == cid
            cluster_size = tl.sum(cluster_mask.to(tl.int32))
            # Skip ids inside the range that have no token in this chunk.
            if cluster_size != 0:
                row_ptrs = x_ptr + all_tokens_idxs[:, None] * stride_x_n + offs_dim[None, :] * stride_x_d
                # Rows of other clusters (and padding) are masked and read as 0.0
                cluster_feats = tl.load(row_ptrs, mask=cluster_mask[:, None], other=0.0)
                cluster_feats = cluster_feats.to(tl.float32)
                sum_feats = tl.sum(cluster_feats, axis=0)
                dest_ptr = sum_ptr + cid * D + offs_dim
                tl.atomic_add(dest_ptr, sum_feats)
                tl.atomic_add(count_ptr + cid, cluster_size)


    def triton_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor,
                                      old_centroids: torch.Tensor):
        """Recompute L2-normalized centroids with the one-token-per-program kernel.

        Args:
            x_norm (Tensor): (N, D) normalized input vectors
            cluster_ids (LongTensor): (N,) cluster assignment per point
            old_centroids (Tensor): (K, D) previous centroids, reused for empty clusters

        Returns:
            Tensor: (K, D) updated and L2-normalized centroids in ``x_norm.dtype``
        """
        assert x_norm.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
        assert x_norm.ndim == 2 and old_centroids.ndim == 2, "Expected 2D tensors"

        n_points, n_feats = x_norm.shape
        n_clusters, centroid_feats = old_centroids.shape
        assert n_feats == centroid_feats, "Feature dimension mismatch"
        assert cluster_ids.shape == (n_points,)

        # Zero-initialised float32 sums and int32 counts filled by the kernel
        target = x_norm.device
        sums = torch.zeros((n_clusters, n_feats), device=target, dtype=torch.float32)
        counts = torch.zeros(n_clusters, device=target, dtype=torch.int32)

        # One program per point
        _centroid_update_kernel[(n_points,)](
            x_norm,
            cluster_ids.to(torch.int32),
            sums,
            counts,
            x_norm.stride(0), x_norm.stride(1),
            sums.stride(0), sums.stride(1),
            n_points, n_feats, n_clusters,
            BLOCK_D=128,
        )

        # Mean per cluster; the clamp only avoids dividing by zero, empty
        # clusters are replaced by their previous centroid right after.
        denom = counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
        means = sums / denom
        is_empty = (counts == 0).unsqueeze(-1)
        merged = torch.where(is_empty, old_centroids.to(torch.float32), means)

        return F.normalize(merged.to(x_norm.dtype), p=2, dim=-1)


    def triton_centroid_update_euclid(x: torch.Tensor, cluster_ids: torch.Tensor,
                                      old_centroids: torch.Tensor):
        """Recompute Euclidean centroids (cluster means) with the one-token-per-program kernel.

        Args:
            x (Tensor): (N, D) input vectors
            cluster_ids (LongTensor): (N,) cluster assignment per point
            old_centroids (Tensor): (K, D) previous centroids, reused for empty clusters

        Returns:
            Tensor: (K, D) updated centroids in ``x.dtype``
        """
        assert x.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
        assert x.ndim == 2 and old_centroids.ndim == 2, "Expected 2D tensors"

        n_points, n_feats = x.shape
        n_clusters, centroid_feats = old_centroids.shape
        assert n_feats == centroid_feats, "Feature dimension mismatch"
        assert cluster_ids.shape == (n_points,)

        # Zero-initialised float32 sums and int32 counts filled by the kernel
        target = x.device
        sums = torch.zeros((n_clusters, n_feats), device=target, dtype=torch.float32)
        counts = torch.zeros(n_clusters, device=target, dtype=torch.int32)

        # One program per point
        _centroid_update_kernel[(n_points,)](
            x,
            cluster_ids.to(torch.int32),
            sums,
            counts,
            x.stride(0), x.stride(1),
            sums.stride(0), sums.stride(1),
            n_points, n_feats, n_clusters,
            BLOCK_D=128,
        )

        # Mean per cluster; the clamp only avoids dividing by zero, empty
        # clusters are replaced by their previous centroid right after.
        denom = counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
        means = sums / denom
        is_empty = (counts == 0).unsqueeze(-1)
        merged = torch.where(is_empty, old_centroids.to(torch.float32), means)

        return merged.to(x.dtype)


    def triton_centroid_update_sorted_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor,
                                             old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
        """Recompute L2-normalized centroids via the sorted, chunked kernel.

        The assignments are sorted inside this function, so ``cluster_ids`` may
        be passed in any order.

        Args:
            x_norm (Tensor): (N, D) normalized input vectors
            cluster_ids (LongTensor): (N,) cluster assignment per point
            old_centroids (Tensor): (K, D) previous centroids, reused for empty clusters
            BLOCK_N (int): Tokens per Triton program

        Returns:
            Tensor: (K, D) updated and L2-normalized centroids in ``x_norm.dtype``
        """
        assert x_norm.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA"
        assert x_norm.ndim == 2 and old_centroids.ndim == 2, "Expected 2D tensors"

        n_points, n_feats = x_norm.shape
        n_clusters, centroid_feats = old_centroids.shape
        assert n_feats == centroid_feats, "Feature dimension mismatch"
        assert cluster_ids.shape == (n_points,)

        # Group points of the same cluster next to each other
        ids_sorted, order = torch.sort(cluster_ids)
        order_i32 = order.to(torch.int32)

        # Zero-initialised float32 sums and int32 counts filled by the kernel
        target = x_norm.device
        sums = torch.zeros((n_clusters, n_feats), device=target, dtype=torch.float32)
        counts = torch.zeros(n_clusters, device=target, dtype=torch.int32)

        n_programs = triton.cdiv(n_points, BLOCK_N)
        _centroid_update_chunk_kernel[(n_programs,)](
            x_norm,
            order_i32,
            ids_sorted.to(torch.int32),
            sums,
            counts,
            x_norm.stride(0), x_norm.stride(1),
            n_points, n_feats, n_clusters,
            BLOCK_N=BLOCK_N,
        )

        # Mean per cluster; empty clusters keep their previous centroid
        denom = counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
        means = sums / denom
        is_empty = (counts == 0).unsqueeze(-1)
        merged = torch.where(is_empty, old_centroids.to(torch.float32), means)

        return F.normalize(merged.to(x_norm.dtype), p=2, dim=-1)


    def triton_centroid_update_sorted_euclid(x: torch.Tensor, cluster_ids: torch.Tensor,
                                             old_centroids: torch.Tensor, *, BLOCK_N: int = 256,
                                             centroid_sums: torch.Tensor = None,
                                             centroid_cnts: torch.Tensor = None,
                                             calculate_new: bool = True):
        """Fast centroid update for Euclidean KMeans via the sorted, chunked kernel.

        The assignments are sorted inside this function, so ``cluster_ids`` may
        be passed in any order. When ``centroid_sums`` / ``centroid_cnts`` are
        supplied the kernel adds into them without resetting them first.

        Args:
            x (Tensor): (N, D) input feature vectors
            cluster_ids (LongTensor): (N,) cluster assignment
            old_centroids (Tensor): (K, D) previous centroids, reused for empty clusters
            BLOCK_N (int): Tokens per Triton program
            centroid_sums (Tensor): (K, D) contiguous pre-allocated accumulation buffer (optional)
            centroid_cnts (Tensor): (K,) contiguous pre-allocated count buffer (optional)
            calculate_new (bool): Whether to compute and return new centroids

        Returns:
            Tensor: (K, D) updated centroids in ``x.dtype``, or None if calculate_new=False

        Raises:
            AssertionError: on device, rank, shape or buffer-layout mismatches.
        """
        assert x.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA device"
        assert x.ndim == 2 and old_centroids.ndim == 2, "Expected 2D tensors"

        N, D = x.shape
        K, D2 = old_centroids.shape
        assert D == D2, "Feature dimension mismatch"
        # Same check as the other centroid-update entry points: a shorter
        # cluster_ids would make the kernel read past its end.
        assert cluster_ids.shape == (N,), "cluster_ids shape mismatch"

        # Sort cluster assignments
        sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids)
        sorted_idx_int = sorted_idx.to(torch.int32)

        if centroid_sums is None:
            centroid_sums = torch.zeros((K, D), device=x.device, dtype=torch.float32)
        else:
            assert centroid_sums.shape == (K, D)
            # The kernel addresses rows as sum_ptr + cid * D, so a strided
            # view would be written at the wrong offsets.
            assert centroid_sums.is_contiguous(), "centroid_sums must be contiguous"

        if centroid_cnts is None:
            centroid_cnts = torch.zeros(K, device=x.device, dtype=torch.int32)
        else:
            assert centroid_cnts.shape == (K,)
            # Counts are addressed as count_ptr + cid (unit stride).
            assert centroid_cnts.is_contiguous(), "centroid_cnts must be contiguous"

        grid = (triton.cdiv(N, BLOCK_N),)
        _centroid_update_chunk_kernel[grid](
            x,
            sorted_idx_int,
            sorted_cluster_ids.to(torch.int32),
            centroid_sums,
            centroid_cnts,
            x.stride(0), x.stride(1),
            N, D, K,
            BLOCK_N=BLOCK_N,
        )

        if not calculate_new:
            # Caller only wanted the accumulation side effect on the buffers.
            return None

        # Mean per cluster; the clamp only avoids dividing by zero, empty
        # clusters are replaced by their previous centroid right after.
        counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
        centroids = centroid_sums / counts_f
        empty_mask = (centroid_cnts == 0).unsqueeze(-1)
        centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
        return centroids.to(x.dtype)


    # -------------------- Single-iteration kernels --------------------

    def _euclid_iter(x, x_sq, centroids):
        """Run one Euclidean k-means step: assign points, then recompute centroids.

        Returns the new centroids, the largest per-centroid L2 displacement,
        and the assignments computed against the incoming ``centroids``.
        """
        assignments = euclid_assign_triton(x, centroids, x_sq)
        updated = triton_centroid_update_sorted_euclid(x, assignments, centroids)
        displacement = updated - centroids
        max_shift = displacement.norm(dim=-1).max()
        return updated, max_shift, assignments


    def _cosine_iter(x_norm, centroids):
        """Run one cosine k-means step: assign points, then recompute normalized centroids.

        Returns the new centroids, the largest per-centroid L2 displacement,
        and the assignments computed against the incoming ``centroids``.
        """
        assignments = cosine_assign_triton(x_norm, centroids)
        updated = triton_centroid_update_sorted_cosine(x_norm, assignments, centroids)
        displacement = updated - centroids
        max_shift = displacement.norm(dim=-1).max()
        return updated, max_shift, assignments


    def _dot_iter(x, centroids):
        """Run one dot-product k-means step: assign points, then recompute centroids.

        Returns the new centroids, the largest per-centroid L2 displacement,
        and the assignments computed against the incoming ``centroids``.
        """
        assignments = cosine_assign_triton(x, centroids)
        # NOTE(review): this reuses the cosine update, which L2-normalizes the
        # new centroids -- confirm that is intended for the raw dot-product metric.
        updated = triton_centroid_update_sorted_cosine(x, assignments, centroids)
        displacement = updated - centroids
        max_shift = displacement.norm(dim=-1).max()
        return updated, max_shift, assignments


    # Set to True to wrap the per-iteration helpers in torch.compile.
    COMPILE_FLAG = False

    # Start from the eager implementations; they stay in place when
    # compilation is disabled or fails.
    _euclid_iter_compiled = _euclid_iter
    _cosine_iter_compiled = _cosine_iter
    _dot_iter_compiled = _dot_iter

    if COMPILE_FLAG:
        try:
            # All-or-nothing: the tuple assignment only happens once every
            # helper has been wrapped successfully.
            _euclid_iter_compiled, _cosine_iter_compiled, _dot_iter_compiled = [
                torch.compile(step_fn, dynamic=True, mode="reduce-overhead")
                for step_fn in (_euclid_iter, _cosine_iter, _dot_iter)
            ]
        except Exception:
            # Deliberate best-effort: keep the eager helpers assigned above.
            pass


    def kmeans_Euclid(x, n_clusters, max_iters=100, tol=0.0, init_centroids=None, verbose=False):
        """
        KMeans clustering in PyTorch using Euclidean distance.

        Args:
            x: Tensor of shape (N, D), N points, D dims.
            n_clusters: Number of clusters.
            max_iters: Max number of iterations (must be at least 1).
            tol: Tolerance for center movement. Iteration stops when the largest
                centroid displacement is strictly below ``tol``; with the
                default 0.0 this never triggers and all ``max_iters`` run.
            init_centroids: Initial centroids (K, D) or None to sample rows of
                ``x`` at random (with replacement).
            verbose: Print progress.

        Returns:
            cluster_ids: (N,) int32 tensor, cluster assignment for each point.
            centroids: (K, D) final cluster centers.
            n_iters: Number of iterations performed.

        Raises:
            ValueError: if ``max_iters`` is smaller than 1.
        """
        assert x.ndim == 2, "x must be 2D tensor (N, D)"
        if max_iters < 1:
            # Without at least one iteration there are no assignments to return.
            raise ValueError(f"max_iters must be at least 1, got {max_iters}")
        N, D = x.shape

        # Pre-compute squared L2 norm of all points in float32, matching what
        # euclid_assign_triton computes itself; summing squares in float16
        # overflows for moderately large vectors.
        x_sq = (x.to(torch.float32) ** 2).sum(dim=-1)  # (N,)

        if init_centroids is None:
            # Randomly select initial centers from x
            indices = torch.randint(0, N, (n_clusters,), device=x.device)
            centroids = x[indices]  # (K, D)
        else:
            centroids = init_centroids
            assert centroids.shape == (n_clusters, D), "init_centroids shape mismatch"

        for it in range(max_iters):
            centroids_new, center_shift, cluster_ids = _euclid_iter_compiled(x, x_sq, centroids)

            if verbose:
                print(f"Iter {it}, center shift: {center_shift.item():.6f}")

            # On convergence keep the centroids the assignments were computed
            # against rather than the (almost identical) updated ones.
            if center_shift < tol:
                break

            centroids = centroids_new

        return cluster_ids, centroids, it + 1


    def kmeans_Cosine(x, n_clusters, max_iters=100, tol=0.0, init_centroids=None, verbose=False):
        """
        KMeans clustering in PyTorch using Cosine similarity.

        Args:
            x: Tensor of shape (N, D), N points, D dims.
            n_clusters: Number of clusters.
            max_iters: Max number of iterations (must be at least 1).
            tol: Tolerance for center movement. Iteration stops when the largest
                centroid displacement is strictly below ``tol``; with the
                default 0.0 this never triggers and all ``max_iters`` run.
            init_centroids: Initial centroids (K, D) or None to sample rows of
                the normalized input at random (with replacement).
            verbose: Print progress.

        Returns:
            cluster_ids: (N,) int32 tensor, cluster assignment for each point.
            centroids: (K, D) final L2-normalized cluster centers.
            n_iters: Number of iterations performed.

        Raises:
            ValueError: if ``max_iters`` is smaller than 1.
        """
        assert x.ndim == 2, "x must be 2D tensor (N, D)"
        if max_iters < 1:
            # Without at least one iteration there are no assignments to return.
            raise ValueError(f"max_iters must be at least 1, got {max_iters}")
        N, D = x.shape

        # Normalize input vectors for cosine similarity
        x_norm = F.normalize(x, p=2, dim=-1)  # (N, D)

        if init_centroids is None:
            # Randomly select initial centers from x_norm
            indices = torch.randint(0, N, (n_clusters,), device=x.device)
            centroids = x_norm[indices]  # (K, D)
        else:
            centroids = init_centroids
            assert centroids.shape == (n_clusters, D), "init_centroids shape mismatch"

        centroids = F.normalize(centroids, p=2, dim=-1)  # Ensure centroids are normalized

        for it in range(max_iters):
            centroids_new, center_shift, cluster_ids = _cosine_iter_compiled(x_norm, centroids)

            if verbose:
                print(f"Iter {it}, center shift: {center_shift.item():.6f}")

            # On convergence keep the centroids the assignments were computed
            # against rather than the (almost identical) updated ones.
            if center_shift < tol:
                break

            centroids = centroids_new

        return cluster_ids, centroids, it + 1


    def kmeans_Dot(x, n_clusters, max_iters=100, tol=0.0, init_centroids=None, verbose=False):
        """
        KMeans clustering in PyTorch using raw dot-product as similarity.

        Args:
            x: Tensor of shape (N, D), N points, D dims.
            n_clusters: Number of clusters.
            max_iters: Max number of iterations (must be at least 1).
            tol: Tolerance for center movement. Iteration stops when the largest
                centroid displacement is strictly below ``tol``; with the
                default 0.0 this never triggers and all ``max_iters`` run.
            init_centroids: Initial centroids (K, D) or None to sample rows of
                ``x`` at random (with replacement).
            verbose: Print progress.

        Returns:
            cluster_ids: (N,) int32 tensor, cluster assignment for each point.
            centroids: (K, D) final cluster centers.
            n_iters: Number of iterations performed.

        Raises:
            ValueError: if ``max_iters`` is smaller than 1.
        """
        assert x.ndim == 2, "x must be 2D tensor (N, D)"
        if max_iters < 1:
            # Without at least one iteration there are no assignments to return.
            raise ValueError(f"max_iters must be at least 1, got {max_iters}")
        N, D = x.shape

        if init_centroids is None:
            # Randomly select initial centers from x
            indices = torch.randint(0, N, (n_clusters,), device=x.device)
            centroids = x[indices]
        else:
            centroids = init_centroids
            assert centroids.shape == (n_clusters, D), "init_centroids shape mismatch"

        for it in range(max_iters):
            centroids_new, center_shift, cluster_ids = _dot_iter_compiled(x, centroids)

            if verbose:
                print(f"Iter {it} (dot), center shift: {center_shift.item():.6f}")

            # On convergence keep the centroids the assignments were computed
            # against rather than the (almost identical) updated ones.
            if center_shift < tol:
                break

            centroids = centroids_new

        return cluster_ids, centroids, it + 1


    def _require_cuda():
        """Check if CUDA is available."""
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is required to run the Triton-backed k-means implementation.")


    class TritonKMeans:
        """
        Fast K-Means clustering implemented with Triton GPU kernels.

        This implementation provides an interface compatible with TorchKMeans
        while leveraging Triton kernels for improved performance.

        Parameters
        ----------
        metric : str, default='euclidean'
            Distance metric to use. Options: 'euclidean', 'cosine', 'dot'
        init : str or torch.Tensor, default='k-means++'
            Method for initialization: 'k-means++', 'random' or user-specified
            tensor of shape (n_clusters, n_features).
        random_state : int, optional
            Random seed for centroid initialization.
        n_clusters : int, default=8
            Number of clusters (k).
        n_init : int, default=10
            Number of times the algorithm will be run with different centroid seeds.
            The final result will be the best output of n_init consecutive runs.
        max_iter : int, default=300
            Maximum number of iterations for a single run.
        tol : float, default=1e-4
            Relative tolerance with regards to inertia to declare convergence.
        verbose : bool, default=False
            Whether to print per-iteration info.
        dtype : torch.dtype, optional
            Compute data type for algorithm.
        device : torch.device, optional
            Target device. Defaults to "cuda:0" when available.
            Currently, only CUDA devices are supported.
        distributed : bool, default=False
            Reserved for future distributed training support (currently not implemented).
        """

        def __init__(
                self,
                metric: str = 'euclidean',
                init: Union[str, torch.Tensor] = 'k-means++',
                random_state: Optional[int] = None,
                n_clusters: int = 8,
                n_init: int = 10,
                max_iter: int = 300,
                tol: float = 1e-4,
                verbose: bool = False,
                dtype: Optional[torch.dtype] = None,
                device: Optional[Union[str, torch.device]] = None,
                distributed: bool = False,
        ):
            """Validate the configuration and set up an unfitted estimator.

            ``device`` may be a ``torch.device`` or anything ``torch.device``
            accepts (e.g. the string ``"cuda:1"``).

            Raises
            ------
            RuntimeError
                If CUDA is not available.
            ValueError
                If ``metric`` is not supported or ``device`` is not a CUDA device.
            NotImplementedError
                If ``distributed`` is True.
            """
            _require_cuda()

            self.metric = metric.lower()
            if self.metric not in ['euclidean', 'cosine', 'dot']:
                raise ValueError(
                    f'Invalid metric value. Must be either "euclidean", "cosine" or "dot". '
                    f'But got "{metric}".'
                )

            # Set distance function based on metric (imported lazily here, as
            # in the rest of this class)
            from pyagc.clusters.torch_kmeans import _pairwise_euclidean, _pairwise_cosine, _pairwise_dot
            self.distance_metric = {
                'euclidean': _pairwise_euclidean,
                'cosine': _pairwise_cosine,
                'dot': _pairwise_dot
            }[self.metric]

            self.n_clusters = int(n_clusters)
            self.n_init = int(n_init)
            self.max_iter = int(max_iter)
            self.tol = float(tol)
            self.verbose = bool(verbose)
            self.dtype = dtype
            self.init = init

            # User-supplied centroids are deterministic, so restarts are pointless.
            if isinstance(self.init, torch.Tensor):
                self.n_init = 1

            if random_state is None:
                random_state = 0
            self.random_state = int(random_state)

            # Device setup
            if device is None:
                self.device = torch.device("cuda:0")
            else:
                # Accept strings such as "cuda:0" as well as torch.device
                # objects; a plain string has no `.type` attribute.
                if not isinstance(device, torch.device):
                    device = torch.device(device)
                if device.type != "cuda":
                    raise ValueError("Only CUDA devices are supported.")
                self.device = device

            # Model state
            self.cluster_centers_: Optional[torch.Tensor] = None
            self.labels_: Optional[torch.Tensor] = None
            self.inertia_: Optional[float] = None

            # Statistics from all runs
            self.stats = {'state': [], 'inertia': [], 'label': []}

            # Distributed training (reserved for future use)
            self.distributed = distributed
            if self.distributed:
                raise NotImplementedError("Distributed training is not yet supported for TritonKMeans.")

            # Backward compatibility attributes
            self.d = None  # Will be set during fit
            self.k = self.n_clusters
            self.niter = self.max_iter
            self.seed = self.random_state
            self.centroids = None  # Alias for cluster_centers_
            self.cluster_ids = None  # Alias for labels_

        @torch.no_grad()
        def initialize(self, X: torch.Tensor, random_state: int) -> torch.Tensor:
            """
            Initializes the cluster centers.

            Parameters
            ----------
            X : torch.Tensor
                The input data of shape (n_samples, n_features).
            random_state : int
                The random seed.

            Returns
            -------
            torch.Tensor
                Initialized cluster centers of shape (n_clusters, n_features).
            """
            num_samples = X.size(0)

            if isinstance(self.init, str):
                generator = torch.Generator(device=str(X.device)).manual_seed(random_state)

                if self.init == 'random':
                    indices = torch.randperm(num_samples, generator=generator, device=X.device)[:self.n_clusters]
                    init_state = X[indices].clone()
                elif self.init == 'k-means++':
                    from pyagc.clusters.torch_kmeans import _kmeans_plusplus
                    init_state, _ = _kmeans_plusplus(
                        X,
                        n_clusters=self.n_clusters,
                        random_state=random_state,
                        pairwise_distance=self.distance_metric
                    )
                else:
                    raise NotImplementedError(f"Unknown init method: {self.init}")
            elif isinstance(self.init, torch.Tensor):
                init_state = self.init.to(device=X.device, dtype=X.dtype)
                assert init_state.shape == (self.n_clusters, X.shape[1]), \
                    f"init shape mismatch: expected ({self.n_clusters}, {X.shape[1]}), got {init_state.shape}"
            else:
                raise NotImplementedError(f"Unsupported init type: {type(self.init)}")

            return init_state

        @torch.no_grad()
        def _single_iteration(self, x: torch.Tensor, centroids: torch.Tensor) -> Tuple[torch.Tensor, float, torch.Tensor]:
            """
            Performs a single k-means iteration: assignment + update.

            Parameters
            ----------
            x : torch.Tensor
                Input data of shape (N, D).
            centroids : torch.Tensor
                Current centroids of shape (K, D).

            Returns
            -------
            Tuple[torch.Tensor, float, torch.Tensor]
                (new_centroids, center_shift, cluster_ids)
            """
            if self.metric == 'euclidean':
                x_sq = (x ** 2).sum(dim=-1)
                new_centroids, center_shift, cluster_ids = _euclid_iter_compiled(x, x_sq, centroids)
            elif self.metric == 'cosine':
                # x is l2-normalized
                centroids = F.normalize(centroids, p=2, dim=-1)
                new_centroids, center_shift, cluster_ids = _cosine_iter_compiled(x, centroids)
            elif self.metric == 'dot':
                new_centroids, center_shift, cluster_ids = _dot_iter_compiled(x, centroids)
            else:
                raise ValueError(f"Unsupported metric: {self.metric}")

            return new_centroids, center_shift.item(), cluster_ids

        @torch.no_grad()
        def _compute_inertia(self, X: torch.Tensor, labels: torch.Tensor, centroids: torch.Tensor) -> float:
            """
            Computes the sum of squared distances of samples to their closest cluster center.
            Optimized version using vectorized operations.

            Parameters
            ----------
            X : torch.Tensor
                Input data of shape (n_samples, n_features).
            labels : torch.Tensor
                Cluster assignments of shape (n_samples,).
            centroids : torch.Tensor
                Cluster centers of shape (n_clusters, n_features).

            Returns
            -------
            float
                Total inertia.
            """
            # Get assigned centroids: (n_samples, n_features)
            assigned_centroids = centroids[labels]

            # Compute pairwise=False distances (element-wise comparison)
            dists = self.distance_metric(X, assigned_centroids, pairwise=False)

            return dists.sum().item()

        @torch.no_grad()
        def fit_predict(self, X: torch.Tensor) -> torch.Tensor:
            """
            Performs k-means clustering on the input data and returns cluster labels.
            Optimized version without inertia computation during training.

            Parameters
            ----------
            X : torch.Tensor
                The input data of shape (n_samples, n_features).

            Returns
            -------
            torch.Tensor
                Cluster assignments of shape (n_samples,).
            """
            if X.ndim != 2:
                raise ValueError("X must be of shape (n_samples, n_features)")

            N, D = X.shape
            self.d = D  # Set feature dimensionality

            # Prepare data
            compute_dtype = self.dtype or X.dtype
            X = X.to(device=self.device, dtype=compute_dtype, copy=False)
            if self.metric == 'cosine':
                X = F.normalize(X, p=2, dim=-1)

            # Compute tolerance
            tol = torch.mean(torch.var(X, dim=0)).item() * self.tol

            min_shift = float('inf')  # Track minimum center shift instead of inertia
            best_centroids = None
            best_labels = None

            # Reset stats (optional: can be removed if not needed)
            self.stats = {'state': [], 'shift': [], 'label': []}

            # Multiple random initializations
            for n_init_idx in range(self.n_init):
                random_state = self.random_state + n_init_idx

                # Initialize centroids
                centroids = self.initialize(X, random_state=random_state)

                old_labels = None
                final_shift = float('inf')

                # Progress bar for this run
                progress_bar = tqdm.tqdm(total=self.max_iter, disable=not self.verbose)

                for n_iter in range(self.max_iter):
                    # Single iteration
                    new_centroids, center_shift, labels = self._single_iteration(X, centroids)

                    # Update progress
                    if self.verbose:
                        progress_bar.set_description(
                            f'n_init {n_init_idx + 1}/{self.n_init}, '
                            f'iter {n_iter}, shift {center_shift:.6f}'
                        )
                        progress_bar.update(1)

                    # Check for convergence
                    if old_labels is not None and torch.equal(labels, old_labels):
                        if self.verbose:
                            print(f"\nConverged at iteration {n_iter}: strict convergence.")
                        final_shift = center_shift
                        break
                    elif center_shift <= tol:
                        if self.verbose:
                            print(f"\nConverged at iteration {n_iter}: "
                                  f"center shift {center_shift:.2e} within tolerance {tol:.2e}.")
                        final_shift = center_shift
                        break

                    old_labels = labels.clone()
                    centroids = new_centroids
                    final_shift = center_shift

                progress_bar.close()

                # Store stats (using final shift instead of inertia)
                self.stats['state'].append(centroids)
                self.stats['shift'].append(final_shift)
                self.stats['label'].append(labels)

                # Track best result based on final center shift
                if final_shift < min_shift:
                    min_shift = final_shift
                    best_centroids = centroids
                    best_labels = labels

            # Convert stats to tensors
            self.stats['state'] = torch.stack(self.stats['state'])
            self.stats['shift'] = torch.tensor(self.stats['shift'])
            self.stats['label'] = torch.stack(self.stats['label'])

            if self.verbose:
                print(f"Final min center shift: {min_shift:.6f}")

            # Store final results
            self.cluster_centers_ = best_centroids
            self.labels_ = best_labels.long()
            self.inertia_ = None  # Set to None since we don't compute it during training

            # Set backward compatibility aliases
            self.centroids = self.cluster_centers_
            self.cluster_ids = self.labels_

            return self.labels_

        @torch.no_grad()
        def fit(self, X: torch.Tensor):
            """
            Fit k-means clustering on the input data.

            Alias for fit_predict that returns self for sklearn-style chaining.

            Parameters
            ----------
            X : torch.Tensor
                Input data of shape (n_samples, n_features).

            Returns
            -------
            self
                Fitted estimator.
            """
            self.fit_predict(X)
            return self

        @torch.no_grad()
        def train(self, X: torch.Tensor):
            """
            Fit k-means clustering on the input data.

            Backward compatibility method - same as fit().

            Parameters
            ----------
            X : torch.Tensor
                Input data of shape (n_samples, n_features).
            """
            self.fit_predict(X)

        @torch.no_grad()
        def predict(self, X: torch.Tensor, soft: bool = False) -> torch.Tensor:
            """
            Assigns samples to clusters based on fixed cluster centers.

            Parameters
            ----------
            X : torch.Tensor
                Input tensor of shape (n_samples, n_features).
            soft : bool, default=False
                If True, returns the soft assignment matrix (probabilities);
                if False, returns hard cluster assignments (indices).

            Returns
            -------
            torch.Tensor
                - If soft=False: (n_samples,) tensor of cluster indices.
                - If soft=True: (n_samples, n_clusters) tensor of probabilities.
            """
            if self.cluster_centers_ is None:
                raise RuntimeError("Model not trained. Call fit() or fit_predict() first.")

            if X.ndim != 2:
                raise ValueError("X must be of shape (n_samples, n_features)")

            N, D = X.shape
            if D != self.d:
                raise ValueError(f"Feature dimension mismatch: expected {self.d}, got {D}")

            # Prepare data
            compute_dtype = self.dtype or X.dtype
            X = X.to(device=self.device, dtype=compute_dtype, copy=False)

            dists = self.distance_metric(X, self.cluster_centers_)  # (n_samples, n_clusters)

            if soft:
                # Convert distances to probabilities
                # Smaller distance => higher probability
                return (-dists.sqrt()).softmax(dim=-1)
            else:
                # Hard assignment: return nearest cluster index
                return dists.argmin(dim=-1)

        @torch.no_grad()
        def transform(self, X: torch.Tensor) -> torch.Tensor:
            """
            Transform data to cluster-distance space.

            Parameters
            ----------
            X : torch.Tensor
                Shape: (n_samples, n_features)

            Returns
            -------
            torch.Tensor
                Distance to each cluster center. Shape: (n_samples, n_clusters)
            """
            if self.cluster_centers_ is None:
                raise RuntimeError("Model not trained. Call fit() or fit_predict() first.")

            if X.ndim != 2:
                raise ValueError("X must be of shape (n_samples, n_features)")

            N, D = X.shape
            if D != self.d:
                raise ValueError(f"Feature dimension mismatch: expected {self.d}, got {D}")

            compute_dtype = self.dtype or X.dtype
            X = X.to(device=self.device, dtype=compute_dtype, copy=False)

            # Compute distances in chunks
            split_size = min(4096, X.size(0))
            all_dists = []

            for chunk in X.split(split_size, dim=0):
                dists = self.distance_metric(chunk, self.cluster_centers_)
                all_dists.append(dists)

            return torch.cat(all_dists, dim=0)

        @torch.no_grad()
        def fit_transform(self, X: torch.Tensor) -> torch.Tensor:
            """
            Fit k-means clustering and transform X to cluster-distance space.

            Parameters
            ----------
            X : torch.Tensor
                Input data of shape (n_samples, n_features).

            Returns
            -------
            torch.Tensor
                Distance to each cluster center. Shape: (n_samples, n_clusters)
            """
            self.fit_predict(X)
            return self.transform(X)

        @torch.no_grad()
        def score(self, X: torch.Tensor) -> float:
            """
            Compute the opposite of the value of X on the K-means objective.

            This method computes inertia on-demand, not during training.

            Parameters
            ----------
            X : torch.Tensor
                Input data of shape (n_samples, n_features).

            Returns
            -------
            float
                Opposite of the sum of squared distances of samples to their
                closest cluster center (negative inertia).
            """
            labels = self.predict(X, soft=False)
            inertia = self._compute_inertia(X, labels, self.cluster_centers_)
            return -inertia

        @torch.no_grad()
        def compute_inertia(self, X: torch.Tensor = None) -> float:
            """
            Compute inertia on-demand after training.

            Parameters
            ----------
            X : torch.Tensor, optional
                Input data. If None, uses the training labels stored in self.labels_.

            Returns
            -------
            float
                The inertia value.
            """
            if self.cluster_centers_ is None:
                raise RuntimeError("Model not trained. Call fit() or fit_predict() first.")

            if X is not None:
                labels = self.predict(X, soft=False)
            else:
                if self.labels_ is None:
                    raise RuntimeError("No training labels available. Provide X explicitly.")
                # Need to recompute using stored training data
                # Since we don't store X, user must provide it
                raise ValueError("X must be provided to compute inertia after training.")

            return self._compute_inertia(X, labels, self.cluster_centers_)

        def __repr__(self) -> str:
            """String representation of the TritonKMeans object."""
            return (
                f"TritonKMeans(metric={self.metric!r}, "
                f"init={self.init!r}, "
                f"n_clusters={self.n_clusters}, "
                f"n_init={self.n_init}, "
                f"max_iter={self.max_iter}, "
                f"tol={self.tol}, "
                f"random_state={self.random_state}, "
                f"verbose={self.verbose})"
            )
# ============================================================
# Fallback stub class when Triton is not available.
# Ensures Sphinx documentation builds and imports succeed
# without requiring Triton to be installed.
# ============================================================
else:
[docs] class TritonKMeans: """ Fast K-Means clustering implemented with Triton GPU kernels. This implementation provides an interface compatible with TorchKMeans while leveraging Triton kernels for improved performance. Parameters ---------- metric : str, default='euclidean' Distance metric to use. Options: 'euclidean', 'cosine', 'dot' init : str or torch.Tensor, default='k-means++' Method for initialization: 'k-means++', 'random' or user-specified tensor of shape (n_clusters, n_features). random_state : int, optional Random seed for centroid initialization. n_clusters : int, default=8 Number of clusters (k). n_init : int, default=10 Number of times the algorithm will be run with different centroid seeds. The final result will be the best output of n_init consecutive runs. max_iter : int, default=300 Maximum number of iterations for a single run. tol : float, default=1e-4 Relative tolerance with regards to inertia to declare convergence. verbose : bool, default=False Whether to print per-iteration info. dtype : torch.dtype, optional Compute data type for algorithm. device : torch.device, optional Target device. Defaults to "cuda:0" when available. Currently, only CUDA devices are supported. distributed : bool, default=False Reserved for future distributed training support (currently not implemented). """
[docs] def __init__(self, *args, **kwargs): raise ImportError( "TritonKMeans requires the 'triton' package and a CUDA-capable GPU.\n" "Install it with:\n" " pip install triton\n" " # or\n" " pip install pyagc[triton]\n\n" "Alternatively, use TorchKMeans for CPU/CUDA compatibility." )
[docs] def fit(self, X): """Fit k-means clustering.""" raise ImportError("triton is not installed.")
[docs] def fit_predict(self, X): """Fit and predict cluster labels.""" raise ImportError("triton is not installed.")
[docs] def predict(self, X, soft=False): """Predict cluster labels.""" raise ImportError("triton is not installed.")
[docs] def transform(self, X): """Transform to cluster-distance space.""" raise ImportError("triton is not installed.")
[docs] def fit_transform(self, X): """Fit and transform.""" raise ImportError("triton is not installed.")
def __repr__(self) -> str: return "TritonKMeans(triton not installed)"