from typing import Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.utils import to_torch_coo_tensor
from pyagc.clusters import BaseClusterHead
EPS = 1e-15
def _mk_smart_teleportation_flow(
A: Tensor, alpha: float = 0.15, n_iters: int = 100, device: str = "cpu"
) -> Tuple[Tensor, Tensor]:
r"""
Construct the smart teleportation flow matrix F and the stationary
node visit probabilities p as described in the Neuromap paper.
Args:
A (torch.Tensor): Adjacency matrix of shape (N, N)
alpha (float, optional): Teleportation probability. Default: 0.15
n_iters (int, optional): Number of power iterations. Default: 100
device (str, optional): Device for computation. Default: "cpu"
Returns:
- F (torch.Tensor): Flow matrix of shape (N, N)
- p (torch.Tensor): Stationary node visit probabilities of shape (N,)
"""
# --- Build transition matrix T ---
T = torch.nan_to_num(
A.T * (torch.sum(A, dim=1) ** (-1.0)).to_dense(), nan=0.0
).T.to(device)
# --- Distribution according to in-degrees ---
e_v = (torch.sum(A, dim=0) / torch.sum(A)).to_dense().to(device)
# --- Power iteration for stationary distribution ---
p = e_v
for _ in range(n_iters):
p = alpha * e_v + (1 - alpha) * (p @ T)
# --- Smart teleportation flow matrix ---
F = alpha * A / torch.sum(A) + (1 - alpha) * (p * T.T).T
return F, p
def _mk_smart_teleportation_flow_sparse(
A: torch.sparse_coo_tensor, alpha: float = 0.15, n_iters: int = 100
) -> Tuple[Tensor, Tensor]:
r"""
Construct the smart teleportation flow matrix F (sparse)
and stationary node visit probabilities p.
Args:
A (torch.sparse_coo_tensor): Sparse adjacency matrix of shape (N, N)
alpha (float, optional): Teleportation probability. Default: 0.15
n_iters (int, optional): Number of power iterations. Default: 100
Returns:
- F (torch.sparse_coo_tensor): Sparse flow matrix
- p (torch.Tensor): Stationary node visit probabilities (N,)
"""
assert A.is_sparse, "A must be torch.sparse_coo_tensor"
device = A.device
A = A.coalesce()
# --- Compute out-degree for each node ---
row_sum = torch.sparse.sum(A, dim=1).to_dense()
row_inv = torch.nan_to_num(row_sum.pow(-1), nan=0.0)
## --- Build transition matrix T = D^{-1}A ---
T = torch.sparse_coo_tensor(
A.indices(),
A.values() * row_inv[A.indices()[0]],
size=A.shape,
device=device
).coalesce() # must be coalesced before accessing values()
# --- Teleportation distribution e_v based on in-degree ---
e_v = torch.sparse.sum(A, dim=0).to_dense()
e_v = e_v / e_v.sum()
# --- Power iteration for stationary distribution p ---
p = e_v.clone()
for _ in range(n_iters):
p_new = alpha * e_v + (1 - alpha) * torch.sparse.mm(T.T, p.unsqueeze(1)).squeeze(1)
if torch.allclose(p_new, p, rtol=1e-6, atol=1e-9):
break
p = p_new
# --- Smart teleportation flow matrix ---
total_A = torch.sparse.sum(A)
p_values = p[A.indices()[0]]
F_values = alpha * A.values() / total_A + (1 - alpha) * p_values * T.values()
F = torch.sparse_coo_tensor(
A.indices(), F_values, size=A.shape, device=device
).coalesce()
return F, p
[docs]class NeuromapClusterHead(BaseClusterHead):
r"""
Neuromap Clustering Head from the paper
`"The Map Equation Goes Neural: Mapping Network Flows with Graph Neural Networks"
<https://arxiv.org/abs/2310.01144>`_ paper (Blöcker et al., NeurIPS 2024).
This module implements a differentiable version of the map equation
for end-to-end optimization with (graph) neural networks.
It learns soft cluster assignments :math:`\mathbf{S}` via a linear projection
from node embeddings :math:`\mathbf{Z}`, and computes the Neuromap loss
(expected per-step description length) following the Minimum Description Length principle:
.. math::
\mathcal{L}(A, S) = q \log q
- \sum_m q_m \log q_m
- \sum_m m_{\text{exit}} \log m_{\text{exit}}
- \sum_u p_u \log p_u
+ \sum_m p_m \log p_m
where all quantities are computed from the soft cluster assignment matrix.
Args:
n_clusters (int): Maximum number of clusters.
n_features (int): Feature dimension of input node embeddings.
alpha (float, optional): Teleportation probability for flow. Default: 0.15.
n_iters (int, optional): Number of power iterations for stationary distribution. Default: 100.
"""
[docs] def __init__(
self,
n_clusters: int,
n_features: int,
alpha: float = 0.15,
n_iters: int = 100,
):
super(NeuromapClusterHead, self).__init__()
self.n_clusters = n_clusters
self.n_features = n_features
self.alpha = alpha
self.n_iters = n_iters
# Cluster centers are learnable parameters.
self.cluster_centers = nn.Parameter(torch.empty(n_clusters, n_features))
self.reset_cluster_centers()
# Cached flow matrix and p may be set externally (lazy initialization)
self.F = None
self.p = None
self.p_log_p = None
[docs] def reset_cluster_centers(self, cluster_centers: Tensor = None) -> None:
r"""
Manually sets the cluster centers.
Args:
cluster_centers (torch.Tensor, optional):
Tensor of shape :obj:`(n_clusters, n_features)` to initialize
the cluster centers. If None, use Xavier uniform initialization.
"""
if cluster_centers is not None:
assert cluster_centers.shape == (self.n_clusters, self.n_features)
with torch.no_grad():
self.cluster_centers.copy_(cluster_centers)
else:
nn.init.xavier_uniform_(self.cluster_centers)
[docs] def build_flow(self, edge_index: Tensor, N: int):
"""Construct sparse flow matrix F and stationary distribution p."""
device = edge_index.device
A = to_torch_coo_tensor(edge_index, size=(N, N)).to(device)
self.F, self.p = _mk_smart_teleportation_flow_sparse(A, alpha=self.alpha, n_iters=self.n_iters)
self.p_log_p = torch.sum(self.p * torch.nan_to_num(torch.log2(self.p), nan=0.0))
[docs] def forward0(self, z: Tensor, edge_index: Tensor) -> Tensor:
r"""
Compute the Neuromap (Map Equation) loss given node embeddings and adjacency.
Args:
z (torch.Tensor): Node embeddings of shape :obj:`(N, F)`.
edge_index (torch.Tensor): Edge indices of shape :obj:`(2, E)`.
Returns:
torch.Tensor: Map equation loss (codelength)
"""
N = z.size(0)
# If flow not built yet, initialize it
if self.F is None or self.p is None:
self.build_flow(edge_index, N)
# === Step 1: Compute soft assignments ===
sim = torch.matmul(z, self.cluster_centers.t()) # (N, K)
S = torch.softmax(sim, dim=-1)
# === Step 2: Pool flow to community level (sparse ops) ===
# (1) F @ S
FS = torch.sparse.mm(self.F, S) # (N, K)
# (2) Sᵀ(FS)
C = torch.matmul(S.T, FS) # (K, K)
diag_C = torch.diag(C)
# === Step 3: Compute map equation quantities ===
q = 1.0 - torch.trace(C)
q_m = torch.sum(C, dim=1) - diag_C
m_exit = torch.sum(C, dim=0) - diag_C
p_m = q_m + torch.sum(C, dim=0)
# === Step 4: Map equation codelength ===
codelength = (
torch.sum(q * torch.nan_to_num(torch.log2(q), nan=0.0))
- torch.sum(q_m * torch.nan_to_num(torch.log2(q_m), nan=0.0))
- torch.sum(m_exit * torch.nan_to_num(torch.log2(m_exit), nan=0.0))
- self.p_log_p
+ torch.sum(p_m * torch.nan_to_num(torch.log2(p_m), nan=0.0))
)
return codelength
[docs] def forward(self, z: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Compute the Neuromap (Map Equation) loss given node embeddings and adjacency.
Args:
z (torch.Tensor): Node embeddings of shape :obj:`(N, F)`.
edge_index (torch.Tensor): Edge indices of shape :obj:`(2, E)`.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Map equation loss (codelength) and collapse_loss
"""
N = z.size(0)
# If flow not built yet, initialize it
if self.F is None or self.p is None:
self.build_flow(edge_index, N)
# === Step 1: Compute soft assignments ===
sim = torch.matmul(z, self.cluster_centers.t()) # (N, K)
S = torch.softmax(sim, dim=-1)
# === Step 2: Pool flow to community level (sparse ops) ===
# (1) F @ S
FS = torch.sparse.mm(self.F, S) # (N, K)
# (2) Sᵀ(FS)
C = torch.matmul(S.T, FS) # (K, K)
diag_C = torch.diag(C)
# === Step 3: Compute map equation quantities ===
q = 1.0 - torch.trace(C)
q_m = torch.sum(C, dim=1) - diag_C
m_exit = torch.sum(C, dim=0) - diag_C
p_m = q_m + torch.sum(C, dim=0)
# === Step 4: Map equation codelength ===
codelength = (
torch.sum(q * torch.nan_to_num(torch.log2(q), nan=0.0))
- torch.sum(q_m * torch.nan_to_num(torch.log2(q_m), nan=0.0))
- torch.sum(m_exit * torch.nan_to_num(torch.log2(m_exit), nan=0.0))
- self.p_log_p
+ torch.sum(p_m * torch.nan_to_num(torch.log2(p_m), nan=0.0))
)
cluster_sizes = S.sum(dim=0) # (K,)
collapse_loss = (cluster_sizes.norm() * (self.n_clusters ** 0.5) / N) - 1
return codelength, collapse_loss
[docs] @torch.no_grad()
def cluster(self, z: Tensor, soft: bool = False) -> Tensor:
r"""
Predicts cluster assignments.
Args:
z (torch.Tensor): Input tensor of shape :obj:`(n_samples, n_features)`.
soft (bool, optional):
If True, returns the soft assignment matrix;
if False, returns hard cluster assignments. (default: :obj:`False`)
Returns:
- If :obj:`soft` is False, :obj:`(n_samples,)` tensor of cluster indices.
- If :obj:`soft` is True, :obj:`(n_samples, n_clusters)` tensor of probabilities.
"""
sim = torch.matmul(z, self.cluster_centers.t())
if soft:
return sim.softmax(dim=-1)
else:
return sim.argmax(dim=-1)