Source code for pyagc.clusters.base_cluster_head

from abc import abstractmethod, ABC
from typing import Any

import torch
from torch import Tensor
import torch.nn as nn


class BaseClusterHead(nn.Module, ABC):
    r"""Abstract base for the clustering head of a neural clustering model.

    Concrete heads must override both :meth:`cluster` and :meth:`forward`;
    the class cannot be instantiated until they do.
    """

    def __init__(self):
        super().__init__()

    @torch.no_grad()
    @abstractmethod
    def cluster(self, *args: Any, **kwargs: Any) -> Tensor:
        r"""Predict cluster assignments.

        The base signature only accepts ``*args`` / ``**kwargs``; the ``soft``
        flag referred to below is not declared here and is expected to be
        defined by the overriding method.

        Returns:
            - If soft=False, :obj:`(n_samples,)` tensor of cluster indices.
            - If soft=True, :obj:`(n_samples, n_clusters)` tensor of
              probabilities.

        Note:
            ``torch.no_grad()`` decorates this base definition only. A
            subclass override replaces the decorated function, so it has to
            apply the decorator itself if it wants gradients disabled.
        """

    @property
    def predict(self):
        r"""Alias for :meth:`cluster`.

        Accessing the property yields the bound :meth:`cluster` method, so
        ``head.predict(...)`` is the same call as ``head.cluster(...)``.
        """
        return self.cluster

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Tensor:
        r"""Run the forward pass of the module."""