from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset
from torch_geometric.nn.models import InnerProductDecoder
from torch_geometric.utils import negative_sampling
from pyagc.clusters import DECClusterHead
from pyagc.models.base import ClusteringModel, LossOutput
from pyagc.utils import filter_kwargs
# Small constant added inside ``torch.log`` calls to avoid evaluating ``log(0)``.
EPS = 1e-15
class DAEGC(ClusteringModel):
    r"""Deep Attentional Embedded Graph Clustering model from the
    `"Attributed Graph Clustering: A Deep Attentional Embedding Approach"
    <https://arxiv.org/abs/1906.06532>`_ paper (Wang et al., IJCAI 2019).

    DAEGC jointly optimizes graph embedding and clustering through:

    1. **Graph Attentional Autoencoder**: Learns representations by encoding
       both structure and content with attention mechanism, then reconstructs
       the graph structure via inner product decoder.
    2. **Self-training Clustering**: Uses confident cluster assignments as soft
       labels to guide the optimization, iteratively refining clustering
       results.

    The total loss combines reconstruction and clustering objectives:

    .. math::
        \mathcal{L} = \mathcal{L}_r + \gamma \mathcal{L}_c

    where:

    - :math:`\mathcal{L}_r` is the binary cross-entropy reconstruction loss
    - :math:`\mathcal{L}_c = KL(P||Q)` is the clustering loss with Student's
      t-distribution
    - :math:`\gamma` is the clustering coefficient

    Args:
        encoder (torch.nn.Module): The graph attention encoder (typically
            GAT-based).
        n_clusters (int): Number of clusters.
        hidden_channels (int): Hidden dimension of node embeddings.
        decoder (torch.nn.Module, optional): The decoder module. If set to
            :obj:`None`, will default to :class:`InnerProductDecoder`.
            (default: :obj:`None`)
        gamma (float, optional): Weight for clustering loss.
            (default: :obj:`10.0`)
        update_interval (int, optional): Number of iterations between target
            distribution updates. Must be at least 1. (default: :obj:`5`)

    Raises:
        ValueError: If :obj:`update_interval` is smaller than 1.
    """

    def __init__(
        self,
        encoder: nn.Module,
        n_clusters: int,
        hidden_channels: int,
        decoder: Optional[nn.Module] = None,
        gamma: float = 10.0,
        update_interval: int = 5,
    ):
        # Fail early: the interval is used as a modulus in the loss
        # computation, where a zero value would only surface as an opaque
        # runtime error in the middle of training.
        if update_interval < 1:
            raise ValueError(
                f"update_interval must be a positive integer "
                f"(got {update_interval})")

        super().__init__()
        self.encoder = encoder
        self.decoder = InnerProductDecoder() if decoder is None else decoder
        self.n_clusters = n_clusters
        self.hidden_channels = hidden_channels
        self.gamma = gamma
        self.update_interval = update_interval

        # Initialize cluster head (DEC-style)
        self.cluster_head = DECClusterHead(n_clusters, hidden_channels)

        # Track training iterations for target distribution updates.
        # Registered as a buffer so it is part of the module's state.
        self.register_buffer('iteration', torch.tensor(0))

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module and restarts the
        target-distribution update schedule.
        """
        reset(self.encoder)
        reset(self.decoder)
        self.cluster_head.reset_cluster_centers()
        # Restart the schedule as well, so that the first clustering step
        # after a reset requests a fresh target distribution.
        self.iteration.zero_()

    def embed(self, x: Tensor, edge_index: Tensor, **kwargs) -> Tensor:
        r"""Computes node embeddings via the encoder.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            **kwargs: Extra arguments; they are passed through
                :func:`filter_kwargs` against the encoder's ``forward``
                before being handed to the encoder.

        Returns:
            The encoder output for the given inputs.
        """
        encoder_kwargs = filter_kwargs(self.encoder.forward, kwargs)
        return self.encoder(x, edge_index, **encoder_kwargs)

    def decode(self, z: Tensor, edge_index: Tensor,
               sigmoid: bool = True) -> Tensor:
        r"""Reconstructs edge probabilities via the decoder.

        Args:
            z (torch.Tensor): Node embeddings.
            edge_index (torch.Tensor): Edges to score.
            sigmoid (bool, optional): Forwarded to the decoder's
                :obj:`sigmoid` argument. (default: :obj:`True`)
        """
        return self.decoder(z, edge_index, sigmoid=sigmoid)

    def forward(self, x: Tensor, edge_index: Tensor, **kwargs) -> Tensor:
        r"""Returns cluster assignments (hard labels).

        Embeds the nodes and asks the cluster head for its assignment with
        :obj:`soft=False`.
        """
        z = self.embed(x, edge_index, **kwargs)
        return self.cluster_head.cluster(z, soft=False)

    def recon_loss(self, z: Tensor, pos_edge_index: Tensor,
                   neg_edge_index: Optional[Tensor] = None) -> Tensor:
        r"""Computes the binary cross-entropy reconstruction loss.

        Given latent variables :obj:`z`, computes the binary cross
        entropy loss for positive edges :obj:`pos_edge_index` and negative
        sampled edges.

        Args:
            z (torch.Tensor): The latent space :math:`\mathbf{Z}`.
            pos_edge_index (torch.Tensor): The positive edges to train against.
            neg_edge_index (torch.Tensor, optional): The negative edges to
                train against. If not given, uses negative sampling to
                calculate negative edges. (default: :obj:`None`)

        Returns:
            The sum of the mean positive-edge and mean negative-edge losses.
        """
        # NOTE(review): ``.mean()`` over an empty edge set evaluates to NaN;
        # callers that may pass no positive edges should guard against it.
        pos_prob = self.decode(z, pos_edge_index, sigmoid=True)
        pos_loss = -torch.log(pos_prob + EPS).mean()

        if neg_edge_index is None:
            neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
        neg_prob = self.decode(z, neg_edge_index, sigmoid=True)
        neg_loss = -torch.log(1 - neg_prob + EPS).mean()

        return pos_loss + neg_loss

    def pretrain_loss(self, x: Tensor, edge_index: Tensor,
                      **kwargs) -> Tensor:
        r"""Computes pretraining loss (reconstruction only).

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            **kwargs: Extra arguments forwarded to :meth:`embed`.

        Returns:
            Reconstruction loss.
        """
        z = self.embed(x, edge_index, **kwargs)
        return self.recon_loss(z, edge_index)

    def _self_training_loss(self, z: Tensor) -> Tensor:
        r"""Runs one clustering step on :obj:`z` and advances the schedule.

        Asks the cluster head for its loss, requesting a target-distribution
        update every :obj:`update_interval` iterations, and increments the
        iteration counter when the module is in training mode.
        """
        # Update every `update_interval` iterations to maintain stability.
        # Converted to a plain bool so the head receives a Python flag
        # rather than a 0-dim tensor.
        update_target = bool(self.iteration % self.update_interval == 0)
        loss_cluster = self.cluster_head(z, update_target=update_target)
        # Only training steps count towards the schedule.
        if self.training:
            self.iteration += 1
        return loss_cluster

    def _make_loss_output(self, loss_recon: Tensor,
                          loss_cluster: Optional[Tensor] = None) -> LossOutput:
        r"""Packs the loss terms into a :class:`LossOutput`.

        Without :obj:`loss_cluster` the total is the reconstruction loss
        alone; otherwise it is ``loss_recon + gamma * loss_cluster``.
        """
        if loss_cluster is None:
            return LossOutput(
                total=loss_recon,
                components={'recon': loss_recon.item()},
            )
        return LossOutput(
            total=loss_recon + self.gamma * loss_cluster,
            components={
                'recon': loss_recon.item(),
                'cluster': loss_cluster.item(),
            },
        )

    def loss(
        self,
        x: Tensor,
        edge_index: Tensor,
        pretrain: bool = False,
        **kwargs
    ) -> LossOutput:
        r"""Computes the total loss.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            pretrain (bool, optional): If :obj:`True`, only compute
                reconstruction loss for pretraining. (default: :obj:`False`)
            **kwargs: Extra arguments forwarded to :meth:`embed`.

        Returns:
            LossOutput containing total loss and individual components.
        """
        z = self.embed(x, edge_index, **kwargs)
        loss_recon = self.recon_loss(z, edge_index)
        if pretrain:
            return self._make_loss_output(loss_recon)

        loss_cluster = self._self_training_loss(z)
        return self._make_loss_output(loss_recon, loss_cluster)

    def loss_batch(self, batch: Data, pretrain: bool = False) -> LossOutput:
        r"""Computes loss for a mini-batch.

        Args:
            batch (Data): A mini-batch from the loader. It must expose
                :obj:`x`, :obj:`edge_index` and :obj:`batch_size`.
            pretrain (bool, optional): If :obj:`True`, only compute
                reconstruction loss for pretraining. (default: :obj:`False`)

        Returns:
            LossOutput containing total loss and individual components.
        """
        batch_size = batch.batch_size

        # Get embeddings for all nodes in the batch (including neighbors)
        z = self.embed(batch.x, batch.edge_index)

        # Slice to seed nodes only.
        # NOTE(review): assumes the loader places the seed nodes in the first
        # `batch_size` rows -- confirm against the loader in use.
        z_seed = z[:batch_size]

        # Keep only the edges whose two endpoints are both seed nodes, so the
        # edge indices stay valid for the sliced embedding matrix.
        src, dst = batch.edge_index[0], batch.edge_index[1]
        edge_mask = (src < batch_size) & (dst < batch_size)
        batch_edge_index = batch.edge_index[:, edge_mask]

        loss_recon = self.recon_loss(z_seed, batch_edge_index)
        if pretrain:
            return self._make_loss_output(loss_recon)

        loss_cluster = self._self_training_loss(z_seed)
        return self._make_loss_output(loss_recon, loss_cluster)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'n_clusters={self.n_clusters}, '
                f'gamma={self.gamma})')