pyagc.clusters.SBMClusterHead
- class SBMClusterHead(n_clusters: int, n_features: int, variant: str = 'bernoulli', eta: float = 3.0)
Bases: BaseClusterHead
Stochastic Block Model (SBM) clustering head from the paper “Differentiable Community Detection with Graph Neural Networks and Stochastic Block Models” (Arliss & Mueller, LoG 2025).
This head learns cluster assignments by maximizing the likelihood of an SBM-based generative model. It supports both Bernoulli and Poisson variants, with optional degree correction.
The cluster assignment matrix \(\mathbf{P} \in [0,1]^{N \times K}\) is obtained by applying a softmax to the similarities between node embeddings and learnable cluster centers, and the structure matrix \(\mathbf{\Theta} \in \mathbb{R}^{K \times K}\) is estimated by maximum likelihood as:
\[\hat{\Theta}_{ij} = \frac{M_{ij}}{n_i n_j}\]
where \(M_{ij}\) is the number of edges between communities \(i\) and \(j\), and \(n_i\) is the number of nodes in community \(i\).
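A minimal NumPy sketch of the two steps above, for illustration only; the head itself is a PyTorch module and its internals may differ. Two assumptions are made here: the similarity is taken to be a plain dot product between embeddings and centers, and \(M\) and \(\mathbf{n}\) are computed from soft assignments as \(\mathbf{P}^T \mathbf{A} \mathbf{P}\) and \(\mathbf{P}^T \mathbf{1}\). The function names `soft_assignments` and `estimate_theta` are hypothetical.

```python
import numpy as np


def soft_assignments(z: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Row-wise softmax over node/center similarities (assumed: dot product)."""
    logits = z @ centers.T                                  # (N, K) similarity scores
    logits = logits - logits.max(axis=1, keepdims=True)     # subtract row max for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=1, keepdims=True)     # each row sums to 1


def estimate_theta(P: np.ndarray, edge_index: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Theta_ij = M_ij / (n_i n_j), with soft counts M = P^T A P and n = P^T 1."""
    num_nodes = P.shape[0]
    A = np.zeros((num_nodes, num_nodes))
    A[edge_index[0], edge_index[1]] = 1.0                   # dense adjacency from a (2, E) edge list
    M = P.T @ A @ P                                         # (K, K) soft edge counts between communities
    n = P.sum(axis=0)                                       # (K,) soft community sizes
    return M / (np.outer(n, n) + eps)                       # eps guards against empty communities


# Example: two communities {0, 1} and {2, 3}; undirected edges stored in both directions.
P_hard = np.eye(2)[[0, 0, 1, 1]]
edge_index = np.array([[0, 1, 2, 3, 1, 2],
                       [1, 0, 3, 2, 2, 1]])
theta = estimate_theta(P_hard, edge_index)  # approximately [[0.5, 0.25], [0.25, 0.5]]
```

With one-hot rows in \(\mathbf{P}\), the soft counts reduce to the hard counts in the formula, which is why the example reproduces the textbook estimate.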
Loss Functions:
(1) Bernoulli SBM:
\[\mathcal{L}_B = -\sum_{(u,v) \in E} \ln(\pi_{uv}) - \eta^{-1} \sum_{(u,v) \notin E} \ln(1 - \pi_{uv})\]
where \(\pi_{uv} = \mathbf{P}_u^T \hat{\Theta} \mathbf{P}_v\) and \(\mathbf{P}_u\) is the \(u\)-th row of \(\mathbf{P}\).
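The Bernoulli objective can be sketched as follows, assuming the sum over non-edges is approximated with sampled node pairs (`neg_edges`), which is how the `eta` negative-sampling ratio reads. The helper names and the `eps` clipping that keeps the logarithms finite are illustrative additions, not part of the documented API.

```python
import numpy as np


def edge_probs(P: np.ndarray, theta: np.ndarray, pairs: np.ndarray) -> np.ndarray:
    """pi_uv = P_u^T Theta P_v for every column (u, v) of a (2, m) pair array."""
    u, v = pairs
    return np.einsum("ik,kl,il->i", P[u], theta, P[v])


def bernoulli_sbm_loss(P, theta, pos_edges, neg_edges, eta=3.0, eps=1e-10):
    """-sum_E ln(pi) - (1/eta) * sum over sampled non-edges of ln(1 - pi)."""
    pi_pos = np.clip(edge_probs(P, theta, pos_edges), eps, 1.0 - eps)
    pi_neg = np.clip(edge_probs(P, theta, neg_edges), eps, 1.0 - eps)
    return -np.log(pi_pos).sum() - np.log1p(-pi_neg).sum() / eta
```

The \(\eta^{-1}\) factor keeps the negative term on roughly the same scale as the positive term when about \(\eta\) non-edges are drawn per observed edge.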
(2) Poisson SBM:
\[\mathcal{L}_P = -\sum_{(u,v) \in E} [\ln(\pi_{uv}) - \pi_{uv}] + \eta^{-1} \sum_{(u,v) \notin E} \pi_{uv}\]
(3) Degree-Corrected variants:
For degree correction, the expected value becomes \(\phi_u \phi_v \mathbf{P}_u^T \hat{\Theta} \mathbf{P}_v\), where:
\[\hat{\phi}_u = (\mathbf{P}_u^T \mathbf{n}) \frac{d_u}{\mathbf{P}_u^T \boldsymbol{\delta}}\]
where \(\mathbf{n}\) is the vector of community sizes, \(d_u\) is the degree of node \(u\), and \(\boldsymbol{\delta}\) is the vector of per-community degree sums.
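A sketch of the degree correction and the Poisson objective, under the same assumptions as a NumPy reading of the formulas (soft counts, sampled non-edges). Here \(\mathbf{n} = \mathbf{P}^T \mathbf{1}\) and \(\boldsymbol{\delta} = \mathbf{P}^T \mathbf{d}\); passing `phi` switches the rate to \(\phi_u \phi_v \mathbf{P}_u^T \hat{\Theta} \mathbf{P}_v\). The function names and the `eps` guards are hypothetical.

```python
import numpy as np


def degree_correction(P: np.ndarray, degrees: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """phi_u = (P_u^T n) * d_u / (P_u^T delta), with soft sizes n and soft degree sums delta."""
    n = P.sum(axis=0)              # (K,) soft community sizes
    delta = P.T @ degrees          # (K,) sum of degrees in each community
    return (P @ n) * degrees / (P @ delta + eps)


def poisson_sbm_loss(P, theta, pos_edges, neg_edges, eta=3.0, phi=None, eps=1e-10):
    """Poisson SBM loss; pass phi to use the degree-corrected rate."""
    u_p, v_p = pos_edges
    u_n, v_n = neg_edges
    rate_pos = np.einsum("ik,kl,il->i", P[u_p], theta, P[v_p])
    rate_neg = np.einsum("ik,kl,il->i", P[u_n], theta, P[v_n])
    if phi is not None:
        rate_pos = phi[u_p] * phi[v_p] * rate_pos
        rate_neg = phi[u_n] * phi[v_n] * rate_neg
    rate_pos = np.maximum(rate_pos, eps)  # keep the logarithm finite
    return -(np.log(rate_pos) - rate_pos).sum() + rate_neg.sum() / eta
```

With hard assignments, \(\hat{\phi}_u\) reduces to \(d_u\) divided by the mean degree of \(u\)'s community, so the \(\hat{\phi}\) values within each community average to 1.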
- Parameters:
n_clusters (int) – Number of clusters.
n_features (int) – Feature dimension of input node embeddings.
variant (str, optional) – SBM variant to use. Options: 'bernoulli', 'poisson', 'bernoulli-dc', 'poisson-dc'; the '-dc' suffix selects the degree-corrected version. (default: 'bernoulli')
eta (float, optional) – Negative sampling ratio (number of negative samples per positive edge). (default: 3.0)
- __init__(n_clusters: int, n_features: int, variant: str = 'bernoulli', eta: float = 3.0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Methods
__init__(n_clusters, n_features[, variant, eta]) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
add_module(name, module) – Add a child module to the current module.
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
children() – Return an iterator over immediate children modules.
cluster(z[, soft]) – Predicts cluster assignments.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
cpu() – Move all model parameters and buffers to the CPU.
cuda([device]) – Move all model parameters and buffers to the GPU.
double() – Casts all floating point parameters and buffers to double datatype.
eval() – Set the module in evaluation mode.
extra_repr() – Return the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
forward(z, edge_index[, num_neg_samples]) – Computes the SBM loss.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – Casts all floating point parameters and buffers to half datatype.
ipu([device]) – Move all model parameters and buffers to the IPU.
load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
modules() – Return an iterator over all modules in the network.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters([recurse]) – Return an iterator over module parameters.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
reset_cluster_centers([cluster_centers]) – Manually sets the cluster centers.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module[, strict]) – Set the submodule given by target if it exists, otherwise throw an error.
share_memory()
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
to(*args, **kwargs) – Move and/or cast the parameters and buffers.
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
train([mode]) – Set the module in training mode.
type(dst_type) – Casts all parameters and buffers to dst_type.
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
T_destination
call_super_init
dump_patches
predict – Alias for cluster().
- reset_cluster_centers(cluster_centers: Optional[Tensor] = None) -> None
Manually sets the cluster centers.
- Parameters:
cluster_centers (torch.Tensor, optional) – Tensor of shape (n_clusters, n_features) used to initialize the cluster centers. If None, Xavier uniform initialization is used.
- Return type:
None
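When `cluster_centers` is None, the default presumably matches `torch.nn.init.xavier_uniform_` applied to an `(n_clusters, n_features)` tensor, which draws from \(U(-a, a)\) with \(a = \text{gain} \cdot \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}\). A NumPy sketch of that bound, with a hypothetical function name:

```python
import numpy as np


def xavier_uniform_centers(n_clusters: int, n_features: int, gain: float = 1.0, rng=None) -> np.ndarray:
    """Draw an (n_clusters, n_features) matrix from U(-a, a), a = gain * sqrt(6 / (fan_in + fan_out))."""
    rng = np.random.default_rng() if rng is None else rng
    # For a 2-D weight of shape (n_clusters, n_features): fan_in = n_features, fan_out = n_clusters.
    bound = gain * np.sqrt(6.0 / (n_features + n_clusters))
    return rng.uniform(-bound, bound, size=(n_clusters, n_features))
```

Any explicitly supplied tensor must have the same `(n_clusters, n_features)` shape.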
- forward(z: Tensor, edge_index: Tensor, num_neg_samples: Optional[int] = None) -> Tuple[Tensor, Tensor]
Computes the SBM loss.
- Parameters:
z (torch.Tensor) – Node embeddings of shape (N, F).
edge_index (torch.Tensor) – Edge indices of shape (2, E).
num_neg_samples (int, optional) – Number of negative samples. If None, uses eta * num_edges. (default: None)
- Returns:
Tuple[Tensor, Tensor] – Tuple of (likelihood_loss, regularization_loss).
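Only the default count (`eta * num_edges`) is documented above; the sampling scheme itself is not. The sketch below assumes uniform sampling of ordered node pairs with replacement, rejecting self-loops and observed edges, and takes `num_edges` to be the number of columns of `edge_index`. `sample_negative_pairs` is a hypothetical helper, not part of pyagc.

```python
import numpy as np


def sample_negative_pairs(edge_index, num_nodes, eta=3.0, num_neg_samples=None, rng=None):
    """Hypothetical helper: sample ordered node pairs (u, v) that are not observed edges."""
    rng = np.random.default_rng() if rng is None else rng
    num_edges = edge_index.shape[1]
    if num_neg_samples is None:
        num_neg_samples = int(eta * num_edges)      # documented default: eta * num_edges
    existing = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    n_observed = sum(1 for u, v in existing if u != v)
    if num_neg_samples > 0 and num_nodes * (num_nodes - 1) <= n_observed:
        raise ValueError("graph has no non-edges to sample from")
    pairs = []
    while len(pairs) < num_neg_samples:
        u, v = (int(x) for x in rng.integers(0, num_nodes, size=2))
        if u != v and (u, v) not in existing:       # reject self-loops and observed edges
            pairs.append((u, v))
    return np.array(pairs, dtype=np.int64).reshape(-1, 2).T   # shape (2, num_neg_samples)
```

If `edge_index` stores each undirected edge in both directions, `num_edges` counts every edge twice, which doubles the default number of negatives.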