Source code for pyagc.models.gbt

from typing import Tuple

import torch
from torch import Tensor, nn
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset
from torch_geometric.transforms import BaseTransform

from pyagc.models.base import TrainableModel, LossOutput
from pyagc.utils import off_diagonal, filter_kwargs


class GBT(TrainableModel):
    r"""
    The G-BT (Graph Barlow Twins) model is proposed in the
    `"Graph Barlow Twins: A Self-Supervised Representation Learning Framework
    for Graphs" <https://arxiv.org/abs/2106.02466>`_ paper
    (Bielak et al., KBS 2022).

    G-BT learns node representations by maximizing the similarity between two
    embeddings of the same node under different augmentations, while reducing
    redundancy between dimensions via a decorrelation objective.

    The loss function is:

    .. math::
        \mathcal{L} = \sum_i (1 - \mathcal{C}_{ii})^2
        + \lambda \sum_{i}\sum_{j \ne i} \mathcal{C}_{ij}^2

    where:

    - :math:`\mathcal{C}` is the cross-correlation matrix between normalized
      embeddings from two views:
      :math:`\mathcal{C}_{ij} = \frac{1}{N} \sum_b \tilde{z}^{(1)}_{b,i}
      \tilde{z}^{(2)}_{b,j}`, with :math:`N` the number of rows (nodes)
      entering the loss.
    - :math:`\tilde{z}` denotes the batch-normalized node embedding
      (per-dimension zero mean and unit standard deviation).
    - :math:`\lambda` balances invariance and redundancy reduction.

    Args:
        encoder (torch.nn.Module): The shared encoder used to produce node
            embeddings.
        transform1 (torch_geometric.transforms.BaseTransform): The first graph
            view transformation.
        transform2 (torch_geometric.transforms.BaseTransform): The second
            graph view transformation.
        lam (float): Weight for redundancy reduction term.
            (default: :obj:`5e-3`)
    """

    def __init__(self, encoder: nn.Module, transform1: BaseTransform,
                 transform2: BaseTransform, lam: float = 5e-3):
        super().__init__()
        self.encoder = encoder
        self.transform1 = transform1
        self.transform2 = transform2
        self.lam = lam

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.encoder)

    def embed(self, *args, **kwargs) -> Tensor:
        r"""Computes node embeddings.

        Positional arguments are forwarded to the encoder unchanged; keyword
        arguments are first passed through :func:`filter_kwargs` together
        with ``self.encoder.forward``.
        """
        return self.encoder(*args, **filter_kwargs(self.encoder.forward, kwargs))

    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        r"""Generates embeddings from two graph augmentations.

        Both transforms receive the same arguments, and the shared encoder is
        then applied to each transformed view.

        Returns:
            Tuple of (first view embeddings, second view embeddings).
        """
        data1 = self.transform1(*args, **kwargs)
        data2 = self.transform2(*args, **kwargs)
        # The transform outputs are unpacked as keyword arguments, so each
        # must be a mapping whose keys match the encoder's parameter names.
        z1 = self.encoder(**data1)
        z2 = self.encoder(**data2)
        return z1, z2

    def _compute_loss(self, z1: Tensor, z2: Tensor,
                      eps: float = 1e-15) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Computes the Graph Barlow Twins loss.

        Args:
            z1 (torch.Tensor): First view embeddings.
            z2 (torch.Tensor): Second view embeddings.
            eps (float): Small constant added to the per-dimension standard
                deviation so that a dimension which is constant across the
                batch normalizes to zeros instead of producing NaN (0 / 0).
                (default: :obj:`1e-15`)

        Returns:
            Tuple of (total_loss, invariance_term, redundancy_term).
        """
        # Standardize every embedding dimension over the rows (dim 0).
        # NOTE: ``std`` is undefined for a single row, so at least two rows
        # are needed for a finite loss.
        z1 = (z1 - z1.mean(0)) / (z1.std(0) + eps)
        z2 = (z2 - z2.mean(0)) / (z2.std(0) + eps)

        c = (z1.T @ z2) / z1.shape[0]  # dim x dim cross-correlation

        # Out-of-place ops: ``torch.diagonal`` returns a view of ``c``, and
        # mutating it would silently rewrite the correlation matrix.
        inv = (torch.diagonal(c) - 1).pow(2).sum()  # invariance loss
        red = off_diagonal(c).pow(2).sum()  # redundancy reduction loss

        loss = inv + self.lam * red
        return loss, inv, red

    def _loss_output(self, z1: Tensor, z2: Tensor) -> LossOutput:
        r"""Wraps the loss of two embedding views into a ``LossOutput``."""
        loss, inv, red = self._compute_loss(z1, z2)
        return LossOutput(
            total=loss,
            components={
                'inv': inv.item(),
                'red': red.item()
            }
        )

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs) -> LossOutput:
        r"""
        Computes the GBT loss with multiple components.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            **kwargs: Extra keyword arguments forwarded to :meth:`forward`.

        Returns:
            LossOutput containing total loss and individual components.
        """
        z1, z2 = self(x, edge_index, **kwargs)
        return self._loss_output(z1, z2)

    def loss_batch(self, batch: Data) -> LossOutput:
        r"""
        Computes loss for a mini-batch with seed node slicing.

        Only the first ``batch.batch_size`` rows of each view's embeddings
        enter the loss (presumably the seed nodes of a sampled sub-graph;
        confirm against the loader in use).

        Args:
            batch (Data): A mini-batch from the loader.

        Returns:
            LossOutput containing total loss and individual components.
        """
        z1, z2 = self(batch.x, batch.edge_index)
        z1 = z1[:batch.batch_size]
        z2 = z2[:batch.batch_size]
        return self._loss_output(z1, z2)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(encoder={self.encoder})"