from typing import Tuple, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.utils import degree, negative_sampling
from pyagc.clusters import BaseClusterHead
# Small positive floor passed to ``clamp`` throughout this module so that
# denominators and arguments of ``log`` never reach exactly zero.
EPS = 1e-15
class SBMClusterHead(BaseClusterHead):
    r"""
    Stochastic Block Model (SBM) Clustering Head from the paper
    `"Differentiable Community Detection with Graph Neural Networks and Stochastic Block Models"
    <https://openreview.net/forum?id=T1vdfm1THf>`_ (Arliss & Mueller, LoG 2025).

    This head learns cluster assignments by maximizing the likelihood of an SBM-based generative model.
    It supports both Bernoulli and Poisson variants, with optional degree correction.

    The cluster assignment matrix :math:`\mathbf{P} \in [0,1]^{N \times K}` is obtained via softmax
    transformation of the similarity between node embeddings and learnable cluster centers, and the
    structure matrix :math:`\mathbf{\Theta} \in \mathbb{R}^{K \times K}` is estimated via MLE as:

    .. math::
        \hat{\Theta}_{ij} = \frac{M_{ij}}{n_i n_j}

    where :math:`M_{ij}` is the number of edges between communities :math:`i` and :math:`j`,
    and :math:`n_i` is the number of nodes in community :math:`i`.

    **Loss Functions:**

    **(1) Bernoulli SBM:**

    .. math::
        \mathcal{L}_B = -\sum_{(u,v) \in E} \ln(\pi_{uv}) - \eta^{-1} \sum_{(u,v) \notin E} \ln(1 - \pi_{uv})

    where :math:`\pi_{uv} = \mathbf{P}_u^T \hat{\Theta} \mathbf{P}_v`.

    **(2) Poisson SBM:**

    .. math::
        \mathcal{L}_P = -\sum_{(u,v) \in E} [\ln(\pi_{uv}) - \pi_{uv}] + \eta^{-1} \sum_{(u,v) \notin E} \pi_{uv}

    **(3) Degree-Corrected variants:**

    For degree correction, the expected value becomes :math:`\phi_u \phi_v \mathbf{P}_u^T \hat{\Theta} \mathbf{P}_v`,
    where:

    .. math::
        \hat{\phi}_u = (\mathbf{P}_u^T \mathbf{n}) \frac{d_u}{\mathbf{P}_u^T \boldsymbol{\delta}}

    with :math:`\boldsymbol{\delta}` being the sum of degrees in each community.

    Note:
        :meth:`forward` averages (rather than sums) the per-edge terms over the
        positive edges and over the sampled negative edges.

    Args:
        n_clusters (int): Number of clusters.
        n_features (int): Feature dimension of input node embeddings.
        variant (str, optional): SBM variant to use. Options: :obj:`'bernoulli'`, :obj:`'poisson'`,
            :obj:`'bernoulli-dc'`, :obj:`'poisson-dc'`. (default: :obj:`'bernoulli'`)
        eta (float, optional): Negative sampling ratio (number of negative samples per positive edge).
            Must be positive. (default: :obj:`3.0`)

    Raises:
        ValueError: If :obj:`variant` is not one of the supported options or
            :obj:`eta` is not positive.
    """

    def __init__(
        self,
        n_clusters: int,
        n_features: int,
        variant: str = 'bernoulli',
        eta: float = 3.0,
    ):
        super().__init__()
        if variant not in ('bernoulli', 'poisson', 'bernoulli-dc', 'poisson-dc'):
            raise ValueError(f"Invalid variant: '{variant}'. Expected one of: "
                             "'bernoulli', 'poisson', 'bernoulli-dc', 'poisson-dc'")
        # forward() divides the negative term by eta and uses eta * num_edges as
        # the default negative sample count, so a non-positive value cannot work.
        if eta <= 0:
            raise ValueError(f"eta must be positive, got {eta}")

        self.n_clusters = n_clusters
        self.n_features = n_features
        self.variant = variant
        self.eta = eta
        self.degree_corrected = variant.endswith('-dc')

        # Cluster centers are learnable parameters.
        self.cluster_centers = nn.Parameter(torch.empty(n_clusters, n_features))
        self.reset_cluster_centers()

    def reset_cluster_centers(self, cluster_centers: Optional[Tensor] = None) -> None:
        r"""
        Sets the cluster centers, either from a given tensor or by re-initializing them.

        Args:
            cluster_centers (torch.Tensor, optional):
                Tensor of shape :obj:`(n_clusters, n_features)` to initialize
                the cluster centers. If None, use Xavier uniform initialization.

        Raises:
            ValueError: If :obj:`cluster_centers` does not have shape
                :obj:`(n_clusters, n_features)`.
        """
        if cluster_centers is None:
            nn.init.xavier_uniform_(self.cluster_centers)
            return

        expected_shape = (self.n_clusters, self.n_features)
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``
        # and ``copy_`` would then silently broadcast a wrongly-shaped tensor.
        if tuple(cluster_centers.shape) != expected_shape:
            raise ValueError(
                f"cluster_centers must have shape {expected_shape}, "
                f"got {tuple(cluster_centers.shape)}"
            )
        with torch.no_grad():
            self.cluster_centers.copy_(cluster_centers)

    def _estimate_structure_matrix(
        self,
        P: Tensor,
        edge_index: Tensor
    ) -> Tensor:
        r"""
        Estimates the structure matrix :math:`\hat{\Theta}` using MLE.

        Args:
            P (torch.Tensor): Soft partition matrix of shape :obj:`(N, K)`.
            edge_index (torch.Tensor): Edge indices of shape :obj:`(2, E)`.

        Returns:
            Structure matrix of shape :obj:`(K, K)`.
        """
        # Soft community sizes: n_i = sum_u P_ui, floored to avoid division by zero.
        n = P.sum(dim=0).clamp(min=EPS)  # (K,)

        # Soft edge counts: M_ij = sum_{(u,v) in E} P_ui * P_vj = P[src]^T @ P[dst].
        src, dst = edge_index
        M = P[src].T @ P[dst]  # (K, K)

        # MLE: Theta_ij = M_ij / (n_i * n_j)
        return M / torch.outer(n, n).clamp(min=EPS)

    def _estimate_degree_correction(
        self,
        P: Tensor,
        edge_index: Tensor
    ) -> Tensor:
        r"""
        Estimates the degree correction vector :math:`\hat{\phi}` using MLE.

        For the Bernoulli variants the result is additionally capped at 1.

        Args:
            P (torch.Tensor): Soft partition matrix of shape :obj:`(N, K)`.
            edge_index (torch.Tensor): Edge indices of shape :obj:`(2, E)`.

        Returns:
            Degree correction vector of shape :obj:`(N,)`.
        """
        N = P.size(0)

        # Node degrees, counted from the source row of edge_index.
        d = degree(edge_index[0], N, dtype=P.dtype).view(-1, 1)  # (N, 1)

        # Soft community sizes.
        n = P.sum(dim=0).clamp(min=EPS)  # (K,)

        # Sum of degrees per community: delta_i = sum_u P_ui * d_u
        delta = (P.T @ d).squeeze(-1).clamp(min=EPS)  # (K,)

        # phi_u = (P_u^T n) * d_u / (P_u^T delta)
        P_n = (P @ n.unsqueeze(-1)).squeeze(-1)  # (N,)
        P_delta = (P @ delta.unsqueeze(-1)).squeeze(-1).clamp(min=EPS)  # (N,)
        phi = P_n * d.squeeze(-1) / P_delta

        # For Bernoulli variants, cap phi at 1.
        if 'bernoulli' in self.variant:
            phi = phi.clamp(max=1.0)
        return phi

    @staticmethod
    def _edge_rate(
        P: Tensor,
        Theta: Tensor,
        edge_index: Tensor,
        phi: Optional[Tensor] = None
    ) -> Tensor:
        r"""
        Computes :math:`\pi_{uv} = \mathbf{P}_u^T \Theta \mathbf{P}_v` for each
        column :math:`(u, v)` of :obj:`edge_index`, scaled by
        :math:`\phi_u \phi_v` when :obj:`phi` is given. Only the listed edges
        are evaluated; no dense :obj:`(N, N)` matrix is built.
        """
        src, dst = edge_index
        pi = (P[src] * (Theta @ P[dst].T).T).sum(dim=1)
        if phi is not None:
            pi = phi[src] * pi * phi[dst]
        return pi

    def forward(
        self,
        z: Tensor,
        edge_index: Tensor,
        num_neg_samples: Optional[int] = None
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Computes the SBM loss.

        Args:
            z (torch.Tensor): Node embeddings of shape :obj:`(N, F)`.
            edge_index (torch.Tensor): Edge indices of shape :obj:`(2, E)`.
            num_neg_samples (int, optional): Number of negative samples. If None,
                uses :obj:`eta * num_edges`. (default: :obj:`None`)

        Returns:
            Tuple of (likelihood_loss, regularization_loss). The likelihood loss is
            the mean positive-edge term plus the mean negative-edge term divided by
            :obj:`eta`; the regularization loss is the norm of
            :math:`1 - \mathrm{diag}(\hat{\Theta})`.
        """
        N = z.size(0)

        # Soft partition via similarity to cluster centers.
        sim = torch.matmul(z, self.cluster_centers.t())  # (N, K)
        P = torch.softmax(sim, dim=-1)  # (N, K)

        # Estimate structure matrix and, if requested, degree correction.
        Theta = self._estimate_structure_matrix(P, edge_index)
        phi = None
        if self.degree_corrected:
            phi = self._estimate_degree_correction(P, edge_index)

        # Number of positive and negative samples.
        num_pos = edge_index.size(1)
        if num_neg_samples is None:
            num_neg_samples = int(self.eta * num_pos)

        # Sample negative edges.
        neg_edge_index = negative_sampling(
            edge_index, num_nodes=N, num_neg_samples=num_neg_samples
        )

        # Edge probabilities / rates for positive and negative edges.
        # NOTE(review): the upper clamp also caps the Poisson rate at 1, which
        # looks Bernoulli-specific; behavior kept as in the original - confirm.
        pi_pos = self._edge_rate(P, Theta, edge_index, phi).clamp(min=EPS, max=1.0 - EPS)
        pi_neg = self._edge_rate(P, Theta, neg_edge_index, phi).clamp(min=EPS, max=1.0 - EPS)

        # Compute loss based on variant.
        if self.variant.startswith('bernoulli'):
            loss_pos = -torch.log(pi_pos).mean()
            # In float32, 1.0 - EPS rounds to exactly 1.0, so the clamp above does
            # not keep pi_neg below 1. Floor (1 - pi_neg) itself to avoid log(0).
            loss_neg = -torch.log((1.0 - pi_neg).clamp(min=EPS)).mean()
        else:  # poisson
            loss_pos = -(torch.log(pi_pos) - pi_pos).mean()
            loss_neg = pi_neg.mean()

        likelihood_loss = loss_pos + loss_neg / self.eta

        # Regularization: encourage diagonal dominance.
        reg_loss = torch.norm(1.0 - Theta.diag())
        return likelihood_loss, reg_loss

    @torch.no_grad()
    def cluster(self, z: Tensor, soft: bool = False) -> Tensor:
        r"""
        Predicts cluster assignments.

        Args:
            z (torch.Tensor): Node embeddings of shape :obj:`(N, F)`.
            soft (bool, optional): If True, returns soft assignments; otherwise hard assignments.
                (default: :obj:`False`)

        Returns:
            Softmax assignment matrix of shape :obj:`(N, K)` if :obj:`soft` is True,
            otherwise the index of the most similar cluster center for each node.
        """
        sim = torch.matmul(z, self.cluster_centers.t())
        if soft:
            return sim.softmax(dim=-1)
        return sim.argmax(dim=-1)
class SBMMatchClusterHead(BaseClusterHead):
    r"""
    Graph Matching SBM Clustering Head from the paper
    `"Differentiable Community Detection with Graph Neural Networks and Stochastic Block Models"
    <https://openreview.net/forum?id=T1vdfm1THf>`_ (Arliss & Mueller, LoG 2025).

    This variant uses the Graph Matching objective, which aligns the graph with its
    community representation by minimizing:

    .. math::
        \mathcal{L}_{Match} = -\mathrm{tr}(\mathbf{A}^T \mathbf{P} \hat{\Theta} \mathbf{P}^T)

    This approach exploits matrix sparsity and is significantly faster than edge sampling methods.

    Note:
        :meth:`forward` divides the trace by the number of edges.

    Args:
        n_clusters (int): Number of clusters.
        n_features (int): Feature dimension of input node embeddings.
    """

    def __init__(self, n_clusters: int, n_features: int):
        super().__init__()
        self.n_clusters = n_clusters
        self.n_features = n_features

        # Cluster centers are learnable parameters.
        self.cluster_centers = nn.Parameter(torch.empty(n_clusters, n_features))
        self.reset_cluster_centers()

    def reset_cluster_centers(self, cluster_centers: Optional[Tensor] = None) -> None:
        r"""
        Sets the cluster centers, either from a given tensor or by re-initializing them.

        Args:
            cluster_centers (torch.Tensor, optional):
                Tensor of shape :obj:`(n_clusters, n_features)` to initialize
                the cluster centers. If None, use Xavier uniform initialization.

        Raises:
            ValueError: If :obj:`cluster_centers` does not have shape
                :obj:`(n_clusters, n_features)`.
        """
        if cluster_centers is None:
            nn.init.xavier_uniform_(self.cluster_centers)
            return

        expected_shape = (self.n_clusters, self.n_features)
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``
        # and ``copy_`` would then silently broadcast a wrongly-shaped tensor.
        if tuple(cluster_centers.shape) != expected_shape:
            raise ValueError(
                f"cluster_centers must have shape {expected_shape}, "
                f"got {tuple(cluster_centers.shape)}"
            )
        with torch.no_grad():
            self.cluster_centers.copy_(cluster_centers)

    def _estimate_structure_matrix(
        self,
        P: Tensor,
        edge_index: Tensor
    ) -> Tensor:
        r"""
        Estimates the structure matrix using MLE:
        :math:`\hat{\Theta}_{ij} = M_{ij} / (n_i n_j)`.

        Args:
            P (torch.Tensor): Soft partition matrix of shape :obj:`(N, K)`.
            edge_index (torch.Tensor): Edge indices of shape :obj:`(2, E)`.

        Returns:
            Structure matrix of shape :obj:`(K, K)`.
        """
        # Soft community sizes, floored to avoid division by zero.
        n = P.sum(dim=0).clamp(min=EPS)
        src, dst = edge_index
        # Soft edge counts between communities.
        M = P[src].T @ P[dst]  # (K, K)
        return M / torch.outer(n, n).clamp(min=EPS)

    def forward(self, z: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:
        r"""
        Computes the Graph Matching loss.

        Args:
            z (torch.Tensor): Node embeddings of shape :obj:`(N, F)`.
            edge_index (torch.Tensor): Edge indices of shape :obj:`(2, E)`.

        Returns:
            Tuple of (likelihood_loss, regularization_loss). The likelihood loss is
            the negated trace divided by the number of edges (treated as 1 when the
            graph has no edges, so the loss is 0 rather than NaN); the
            regularization loss is the norm of :math:`1 - \mathrm{diag}(\hat{\Theta})`.
        """
        # Soft partition via similarity to cluster centers.
        sim = torch.matmul(z, self.cluster_centers.t())  # (N, K)
        P = torch.softmax(sim, dim=-1)  # (N, K)

        # Estimate structure matrix.
        Theta = self._estimate_structure_matrix(P, edge_index)

        # A^T @ P without materializing A: for each node u, accumulate P_v over
        # every edge (v, u) listed in edge_index.
        src, dst = edge_index
        A_T_P = torch.zeros_like(P)
        A_T_P.index_add_(0, dst, P[src])

        # trace(A^T @ P @ Theta @ P^T) = trace(P^T @ A^T @ P @ Theta), which only
        # needs (K, K) products.
        P_T_A_T_P = P.T @ A_T_P  # (K, K)

        # Guard the normalizer: with zero edges the trace is 0 and dividing by 0
        # would yield NaN.
        num_edges = max(edge_index.size(1), 1)
        likelihood_loss = -torch.trace(P_T_A_T_P @ Theta) / num_edges

        # Regularization: encourage diagonal dominance.
        reg_loss = torch.norm(1.0 - Theta.diag())
        return likelihood_loss, reg_loss

    @torch.no_grad()
    def cluster(self, z: Tensor, soft: bool = False) -> Tensor:
        r"""
        Predicts cluster assignments.

        Args:
            z (torch.Tensor): Node embeddings of shape :obj:`(N, F)`.
            soft (bool, optional): If True, returns soft assignments; otherwise hard assignments.
                (default: :obj:`False`)

        Returns:
            Softmax assignment matrix of shape :obj:`(N, K)` if :obj:`soft` is True,
            otherwise the index of the most similar cluster center for each node.
        """
        sim = torch.matmul(z, self.cluster_centers.t())
        if soft:
            return sim.softmax(dim=-1)
        return sim.argmax(dim=-1)