"""Label-based clustering evaluation metrics (``pyagc.metrics.label_metrics``)."""

import warnings

import numpy as np
import torch
from sklearn.metrics import (
    normalized_mutual_info_score,
    adjusted_rand_score,
    homogeneity_score,
    completeness_score,
    accuracy_score,
    f1_score,
)
from scipy.optimize import linear_sum_assignment
from typing import Union, Tuple, Dict


def _to_numpy(x):
    r"""Return ``x`` as a NumPy array.

    NumPy arrays are returned as-is (same object, no copy). PyTorch tensors
    are detached from the autograd graph and moved to CPU before conversion.
    Anything else is handed to :func:`numpy.asarray`.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        # detach() drops autograd history and cpu() brings the data to host memory,
        # both of which .numpy() requires.
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _unused_labels(taken: np.ndarray, count: int) -> list:
    r"""Return ``count`` distinct label values that do not occur in ``taken``.

    Numeric (bool/int/uint/float) label sets are extended upward from their
    maximum. For any other dtype, string placeholders of the form
    ``"__unmatched_<k>__"`` that are not already in ``taken`` are used.
    """
    if count <= 0:
        return []
    if taken.dtype.kind in 'biuf':
        # Values strictly above the current maximum can never collide with an existing label.
        start = taken.max() + 1 if taken.size else 0
        return [start + k for k in range(count)]
    used = set(taken.tolist())
    fresh = []
    k = 0
    while len(fresh) < count:
        candidate = f"__unmatched_{k}__"
        if candidate not in used:
            fresh.append(candidate)
        k += 1
    return fresh


def _hungarian_match(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    r"""Align the predicted clusters with the true labels using the Hungarian algorithm.

    The function builds the contingency table between true labels and predicted
    clusters. It then uses the Hungarian algorithm to find the one-to-one
    assignment of predicted clusters to true labels that maximises the number
    of overlapping samples.

    When there are more predicted clusters than true labels, the clusters left
    unassigned receive fresh label values that do not occur in ``y_true``. Their
    samples therefore always count as misclassified, and they are never merged
    with another cluster, so the partition itself is preserved.

    Args:
        y_true (np.ndarray): True labels of shape :obj:`(n_samples,)`.
        y_pred (np.ndarray): Predicted cluster labels of shape :obj:`(n_samples,)`.

    Returns:
        Aligned predicted labels of shape :obj:`(n_samples,)`.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` contain a different number of samples.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"y_true and y_pred must have the same length, got "
            f"{y_true.shape[0]} and {y_pred.shape[0]}.")
    if y_true.size == 0:
        # Nothing to align; avoids running the assignment on an empty table.
        return y_pred.copy()

    # For every sample, *_idx holds the position of its label within the sorted unique labels.
    true_labels, true_idx = np.unique(y_true, return_inverse=True)
    pred_labels, pred_idx = np.unique(y_pred, return_inverse=True)
    true_idx, pred_idx = true_idx.ravel(), pred_idx.ravel()

    if len(true_labels) != len(pred_labels):
        warnings.warn(
            f"Number of predicted clusters ({len(pred_labels)}) differs from "
            f"true labels ({len(true_labels)}). Results may be unreliable."
        )

    # overlap[i, j] = number of samples with true label i that fall in predicted cluster j.
    overlap = np.zeros((len(true_labels), len(pred_labels)), dtype=np.int64)
    np.add.at(overlap, (true_idx, pred_idx), 1)

    # Hungarian algorithm: one-to-one cluster -> label assignment maximising total overlap.
    row_ind, col_ind = linear_sum_assignment(overlap, maximize=True)

    # Output value for every predicted cluster, indexed like ``pred_labels``.
    targets = [None] * len(pred_labels)
    for i, j in zip(row_ind, col_ind):
        targets[j] = true_labels[i]

    # Clusters left without a partner get labels absent from y_true, so they can neither be
    # scored as correct by accident nor collapse onto another cluster.
    assigned = np.zeros(len(pred_labels), dtype=bool)
    assigned[col_ind] = True
    unassigned = np.flatnonzero(~assigned)
    for j, label in zip(unassigned, _unused_labels(true_labels, len(unassigned))):
        targets[j] = label

    return np.asarray(targets)[pred_idx]


# Metric names accepted by ``label_metrics``; anything else is rejected with a ValueError.
VALID_METRICS = (
    "NMI",
    "ARI",
    "Homo",
    "Comp",
    "ACC",
    "F1",
)


def label_metrics(
    y_true: Union[torch.Tensor, np.ndarray],
    y_pred: Union[torch.Tensor, np.ndarray],
    metrics: Union[str, Tuple[str, ...]] = ('NMI', 'ARI', 'ACC', 'F1'),
) -> Dict[str, float]:
    r"""Compute clustering evaluation metrics.

    NMI, ARI, homogeneity and completeness compare partitions and do not depend
    on how clusters are numbered, so they are computed on the raw predictions.

    If accuracy or Macro-F1 score is requested, the predicted clusters are
    first aligned with the true labels using the Hungarian algorithm to
    account for label mismatches. The alignment is computed once and used only
    for those two scores.

    Result keys appear in the order NMI, ARI, Homo, Comp, ACC, F1, regardless
    of the order in which they were requested.

    Args:
        y_true (torch.Tensor or np.ndarray): True labels of shape :obj:`(n_samples,)`.
        y_pred (torch.Tensor or np.ndarray): Predicted cluster labels of shape
            :obj:`(n_samples,)`.
        metrics (str or tuple of str, optional): The metrics to compute. Can be
            one or more of :obj:`('NMI', 'ARI', 'Homo', 'Comp', 'ACC', 'F1')`.
            Default is :obj:`('NMI', 'ARI', 'ACC', 'F1')`.

    Returns:
        Dictionary mapping metric names to their computed values.

    Raises:
        ValueError: If any requested metric is not one of :obj:`VALID_METRICS`.

    Example:
        >>> result = label_metrics(y_true, y_pred, metrics=('NMI', 'ARI', 'ACC'))
        >>> print(result)
        {'NMI': 0.85, 'ARI': 0.72, 'ACC': 0.89}
    """
    # Validate metrics argument
    if isinstance(metrics, str):
        metrics = (metrics,)
    # Materialise once so one-shot iterables (e.g. generators) survive the repeated checks below.
    metrics = tuple(metrics)
    invalid_metrics = [metric for metric in metrics if metric not in VALID_METRICS]
    if invalid_metrics:
        raise ValueError(
            f"Invalid metric(s): {', '.join(map(str, invalid_metrics))}. "
            f"Valid metrics are: {', '.join(VALID_METRICS)}.")

    # Convert torch tensors to numpy arrays if needed
    y_true = _to_numpy(y_true)
    y_pred = _to_numpy(y_pred)

    results = {}

    # Partition-comparison scores: computed on the raw predictions so that their values never
    # depend on whether the Hungarian relabelling below is also performed.
    if 'NMI' in metrics:
        results['NMI'] = normalized_mutual_info_score(y_true, y_pred)
    if 'ARI' in metrics:
        results['ARI'] = adjusted_rand_score(y_true, y_pred)
    if 'Homo' in metrics:
        results['Homo'] = homogeneity_score(y_true, y_pred)
    if 'Comp' in metrics:
        results['Comp'] = completeness_score(y_true, y_pred)

    # Label-matching scores need cluster ids mapped onto class ids; align only once.
    if 'ACC' in metrics or 'F1' in metrics:
        y_pred_aligned = _hungarian_match(y_true, y_pred)
        if 'ACC' in metrics:
            results['ACC'] = accuracy_score(y_true, y_pred_aligned)
        if 'F1' in metrics:
            results['F1'] = f1_score(y_true, y_pred_aligned, average='macro')

    return results