# Source code for pyagc.models.neuromap

from typing import Any

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset

from pyagc.clusters import NeuromapClusterHead
from pyagc.models.base import ClusteringModel, LossOutput
from pyagc.utils import filter_kwargs


class Neuromap(ClusteringModel):
    r"""
    The **Neuromap** model implements the differentiable map equation from
    `"The Map Equation Goes Neural: Mapping Network Flows with Graph Neural
    Networks" <https://arxiv.org/abs/2310.01144>`_ (Blöcker et al.,
    NeurIPS 2024).

    This model performs **unsupervised graph clustering** by combining a
    graph encoder (e.g., GCN, GraphSAGE) with the
    :class:`~pyagc.clusters.NeuromapClusterHead`. The encoder produces node
    embeddings :math:`\mathbf{Z}`, which are projected into soft cluster
    assignments :math:`\mathbf{S}`. The model then minimizes the
    differentiable **Map Equation loss**, which measures the expected
    description length of random walks on the graph according to the
    Minimum Description Length (MDL) principle:

    .. math::
        \mathcal{L}(A, S) = q \log q - \sum_m q_m \log q_m
        - \sum_m m_{\text{exit}} \log m_{\text{exit}}
        - \sum_u p_u \log p_u + \sum_m p_m \log p_m

    where all terms are derived from the soft community flow structure
    induced by :math:`S`. The codelength itself trades off model complexity
    against data fit, which is what lets Neuromap settle on fewer than
    :obj:`n_clusters` effective communities.

    The full training objective (see :meth:`loss`) additionally adds the
    collapse term returned by the cluster head, weighted by :obj:`lam`:

    .. math::
        \mathcal{L}_{\text{total}} = \mathcal{L}(A, S)
        + \lambda \, \mathcal{L}_{\text{collapse}}

    Args:
        encoder (torch.nn.Module): Node encoder producing node embeddings.
        n_features (int): Feature dimension of encoder outputs.
        n_clusters (int): Maximum number of clusters to consider.
        lam (float, optional): Weight :math:`\lambda` of the collapse
            regularization term added to the codelength.
            (default: :obj:`1.0`)
        alpha (float, optional): Teleportation probability, forwarded to the
            cluster head. (default: :obj:`0.15`)
        n_iters (int, optional): Power iteration steps for the stationary
            distribution, forwarded to the cluster head.
            (default: :obj:`100`)
    """

    def __init__(
        self,
        encoder: nn.Module,
        n_features: int,
        n_clusters: int,
        lam: float = 1.0,
        alpha: float = 0.15,
        n_iters: int = 100,
    ):
        super().__init__()
        self.encoder = encoder
        self.n_features = n_features
        self.n_clusters = n_clusters
        self.lam = lam
        self.alpha = alpha
        self.n_iters = n_iters

        # Neuromap clustering head: owns the cluster projection and computes
        # the map-equation codelength plus the collapse term.
        self.head = NeuromapClusterHead(
            n_clusters=n_clusters,
            n_features=n_features,
            alpha=alpha,
            n_iters=n_iters,
        )

    def reset_parameters(self):
        r"""Resets all learnable parameters of the encoder and the cluster head.

        Also clears the head's cached ``F``, ``p`` and ``p_log_p`` attributes.
        NOTE(review): these look like cached flow quantities that the head
        recomputes lazily when they are ``None`` — confirm in
        :class:`~pyagc.clusters.NeuromapClusterHead`.
        """
        reset(self.encoder)
        self.head.reset_cluster_centers()
        # Drop cached state on the head so nothing computed before the reset
        # is reused afterwards.
        self.head.F = None
        self.head.p = None
        self.head.p_log_p = None

    def embed(self, *args, **kwargs) -> Tensor:
        r"""Compute node embeddings via the encoder.

        Keyword arguments not accepted by ``self.encoder.forward`` are
        filtered out before the call, so callers may pass extra kwargs safely.
        """
        return self.encoder(*args, **filter_kwargs(self.encoder.forward, kwargs))

    def forward(self, *args, **kwargs) -> Tensor:
        r"""Predicts hard cluster assignments from current parameters.

        Arguments are forwarded to :meth:`embed`; the resulting embeddings are
        passed to the head's ``cluster`` method.
        """
        z = self.embed(*args, **kwargs)
        return self.head.cluster(z)

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs: Any) -> LossOutput:
        r"""
        Computes the Neuromap loss with multiple components.

        The total loss is ``codelength + lam * collapse``, where both terms
        are returned by the cluster head.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            **kwargs: Extra arguments forwarded to the encoder (filtered by
                :meth:`embed` to those the encoder accepts).

        Returns:
            LossOutput containing the total loss (a tensor, differentiable)
            and the individual components as Python floats
            (``'codelength'`` and ``'collapse'``).
        """
        z = self.embed(x, edge_index, **kwargs)
        codelength, collapse_loss = self.head(z=z, edge_index=edge_index)
        loss = codelength + self.lam * collapse_loss
        return LossOutput(
            total=loss,
            components={
                # .item() detaches to plain floats for logging only.
                'codelength': codelength.item(),
                'collapse': collapse_loss.item(),
            },
        )

    def loss_batch(self, batch: Data, **kwargs: Any):
        r"""Neuromap currently does not support mini-batch training.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support batch training.")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(encoder={self.encoder.__class__.__name__}, n_clusters={self.n_clusters})"