# Source code for pyagc.models.autoencoder

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Data
from torch_geometric.nn.inits import reset
from torch_geometric.nn.models import InnerProductDecoder
from torch_geometric.utils import negative_sampling

from pyagc.models.base import TrainableModel, LossOutput
from pyagc.utils import filter_kwargs

# Small constant added to the argument of ``torch.log`` in the loss functions
# of this module so the logarithm is not evaluated at exactly zero.
EPS = 1e-15
# Upper bound passed to ``clamp(max=...)`` on the encoder's log-std output
# before it is exponentiated in the variational models.
MAX_LOGSTD = 10


class GAE(TrainableModel):
    r"""The Graph Auto-Encoder model from the
    `"Variational Graph Auto-Encoders" <https://arxiv.org/abs/1611.07308>`_
    paper based on user-defined encoder and decoder models.

    Args:
        encoder (torch.nn.Module): The encoder module.
        decoder (torch.nn.Module, optional): The decoder module. If set to
            :obj:`None`, will default to the
            :class:`torch_geometric.nn.models.InnerProductDecoder`.
            (default: :obj:`None`)
    """

    def __init__(self, encoder: Module, decoder: Optional[Module] = None):
        super().__init__()
        self.encoder = encoder
        self.decoder = InnerProductDecoder() if decoder is None else decoder
        # Call the GAE implementation explicitly instead of dispatching
        # through ``self``. A subclass override of ``reset_parameters`` may
        # touch attributes that the subclass assigns only after this
        # constructor returns (e.g. a discriminator), and would fail here.
        GAE.reset_parameters(self)

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.encoder)
        reset(self.decoder)

    def forward(self, *args, **kwargs) -> Tensor:
        r"""Alias for :meth:`embed`."""
        return self.embed(*args, **kwargs)

    def embed(self, *args, **kwargs) -> Tensor:
        r"""Computes node embeddings via the encoder.

        Positional arguments are forwarded unchanged; keyword arguments are
        first passed through :func:`filter_kwargs` together with the
        encoder's ``forward`` method.
        """
        return self.encoder(*args,
                            **filter_kwargs(self.encoder.forward, kwargs))

    def decode(self, *args, **kwargs) -> Tensor:
        r"""Runs the decoder and computes edge probabilities."""
        return self.decoder(*args, **kwargs)

    def recon_loss(self, z: Tensor, pos_edge_index: Tensor,
                   neg_edge_index: Optional[Tensor] = None) -> Tensor:
        r"""Given latent variables :obj:`z`, computes the binary cross
        entropy loss for positive edges :obj:`pos_edge_index` and negative
        sampled edges.

        Args:
            z (torch.Tensor): The latent space :math:`\mathbf{Z}`.
            pos_edge_index (torch.Tensor): The positive edges to train
                against.
            neg_edge_index (torch.Tensor, optional): The negative edges to
                train against. If not given, uses negative sampling to
                calculate negative edges. (default: :obj:`None`)

        Returns:
            The sum of the mean positive-edge and mean negative-edge losses.
        """
        pos_loss = -torch.log(
            self.decoder(z, pos_edge_index, sigmoid=True) + EPS).mean()

        if neg_edge_index is None:
            # Sample negatives among the ``z.size(0)`` nodes being scored.
            neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
        neg_loss = -torch.log(
            1 - self.decoder(z, neg_edge_index, sigmoid=True) + EPS).mean()

        return pos_loss + neg_loss

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs) -> Tensor:
        r"""Computes the reconstruction loss for GAE.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices (positive edges).
            **kwargs: Extra keyword arguments forwarded to :meth:`embed`.

        Returns:
            Reconstruction loss as a scalar tensor.
        """
        z = self.embed(x, edge_index, **kwargs)
        return self.recon_loss(z, edge_index)

    def loss_batch(self, batch: Data) -> Tensor:
        r"""Computes loss for a mini-batch with seed node slicing.

        Only the first :obj:`batch.batch_size` embeddings are kept, and only
        edges whose two endpoints both fall inside that range are used as
        positive edges.

        Args:
            batch (Data): A mini-batch from the loader.
                NOTE(review): assumes the seed nodes occupy the first
                ``batch.batch_size`` positions -- confirm against the loader.

        Returns:
            Reconstruction loss as a scalar tensor.
        """
        z = self.embed(batch.x, batch.edge_index)
        z = z[:batch.batch_size]

        # Extract edges within the batch
        batch_mask = ((batch.edge_index[0] < batch.batch_size)
                      & (batch.edge_index[1] < batch.batch_size))
        batch_edge_index = batch.edge_index[:, batch_mask]

        return self.recon_loss(z, batch_edge_index)
class VGAE(GAE):
    r"""The Variational Graph Auto-Encoder model from the
    `"Variational Graph Auto-Encoders" <https://arxiv.org/abs/1611.07308>`_
    paper.

    Args:
        encoder (torch.nn.Module): The encoder module. Its output is
            unpacked into two tensors that this class uses as :math:`\mu`
            and :math:`\log\sigma`.
        decoder (torch.nn.Module, optional): The decoder module. If set to
            :obj:`None`, will default to the
            :class:`torch_geometric.nn.models.InnerProductDecoder`.
            (default: :obj:`None`)
    """

    def __init__(self, encoder: Module, decoder: Optional[Module] = None):
        super().__init__(encoder, decoder)
        # Most recent encoder outputs; populated by :meth:`embed`.
        self.__mu__ = self.__logstd__ = None

    def reparametrize(self, mu: Tensor, logstd: Tensor) -> Tensor:
        r"""Reparametrization trick for variational inference.

        In training mode returns :obj:`mu` plus standard-normal noise scaled
        by :obj:`exp(logstd)`; otherwise returns :obj:`mu` unchanged.
        """
        if not self.training:
            return mu
        noise = torch.randn_like(logstd)
        return mu + noise * torch.exp(logstd)

    def embed(self, *args, **kwargs) -> Tensor:
        r"""Computes node embeddings via the variational encoder.

        The two encoder outputs are stored on the instance (the second one
        clamped from above at :obj:`MAX_LOGSTD`) and then combined by
        :meth:`reparametrize`.
        """
        encoder_kwargs = filter_kwargs(self.encoder.forward, kwargs)
        mean, raw_logstd = self.encoder(*args, **encoder_kwargs)
        self.__mu__ = mean
        self.__logstd__ = raw_logstd.clamp(max=MAX_LOGSTD)
        return self.reparametrize(self.__mu__, self.__logstd__)

    def kl_loss(self, mu: Optional[Tensor] = None,
                logstd: Optional[Tensor] = None) -> Tensor:
        r"""Computes the KL loss, either for the passed arguments :obj:`mu`
        and :obj:`logstd`, or based on latent variables from last encoding.

        Args:
            mu (torch.Tensor, optional): The latent space for :math:`\mu`.
                If set to :obj:`None`, uses the last computation of
                :math:`\mu`. (default: :obj:`None`)
            logstd (torch.Tensor, optional): The latent space for
                :math:`\log\sigma`. If set to :obj:`None`, uses the last
                computation of :math:`\log\sigma`; a passed tensor is
                clamped from above at :obj:`MAX_LOGSTD`.
                (default: :obj:`None`)
        """
        if mu is None:
            mu = self.__mu__
        if logstd is None:
            # The stored value was already clamped in :meth:`embed`.
            logstd = self.__logstd__
        else:
            logstd = logstd.clamp(max=MAX_LOGSTD)

        per_node = torch.sum(
            1 + 2 * logstd - mu ** 2 - logstd.exp() ** 2, dim=1)
        return -0.5 * torch.mean(per_node)

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs) -> LossOutput:
        r"""Computes the VGAE loss with reconstruction and KL divergence
        components.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices (positive edges).
            **kwargs: Extra keyword arguments forwarded to :meth:`embed`.

        Returns:
            LossOutput containing total loss and individual components.
        """
        latent = self.embed(x, edge_index, **kwargs)
        recon = self.recon_loss(latent, edge_index)
        kl = self.kl_loss()

        parts = {'recon': recon.item(), 'kl': kl.item()}
        return LossOutput(total=recon + kl, components=parts)

    def loss_batch(self, batch: Data) -> LossOutput:
        r"""Computes loss for a mini-batch with seed node slicing.

        Embeddings, :math:`\mu` and :math:`\log\sigma` are truncated to the
        first :obj:`batch.batch_size` rows, and only edges with both
        endpoints inside that range are used as positive edges.

        Args:
            batch (Data): A mini-batch from the loader.

        Returns:
            LossOutput containing total loss and individual components.
        """
        num_seed = batch.batch_size
        edge_index = batch.edge_index

        latent = self.embed(batch.x, edge_index)[:num_seed]
        mean = self.__mu__[:num_seed]
        log_sigma = self.__logstd__[:num_seed]

        # Keep only edges whose source and target are both seed nodes.
        source_ok = edge_index[0] < num_seed
        target_ok = edge_index[1] < num_seed
        seed_edges = edge_index[:, source_ok & target_ok]

        recon = self.recon_loss(latent, seed_edges)
        kl = self.kl_loss(mean, log_sigma)

        parts = {'recon': recon.item(), 'kl': kl.item()}
        return LossOutput(total=recon + kl, components=parts)
class ARGA(GAE):
    r"""The Adversarially Regularized Graph Auto-Encoder model from the
    `"Adversarially Regularized Graph Autoencoder for Graph Embedding"
    <https://arxiv.org/abs/1802.04407>`_ paper.

    .. note::
        ARGA requires a two-phase training procedure (encoder +
        discriminator). Use :meth:`train_encoder` and
        :meth:`train_discriminator` separately, or implement a custom
        training loop.

    Args:
        encoder (torch.nn.Module): The encoder module.
        discriminator (torch.nn.Module): The discriminator module.
        decoder (torch.nn.Module, optional): The decoder module. If set to
            :obj:`None`, will default to the
            :class:`torch_geometric.nn.models.InnerProductDecoder`.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        encoder: Module,
        discriminator: Module,
        decoder: Optional[Module] = None,
    ):
        super().__init__(encoder, decoder)
        self.discriminator = discriminator
        # The discriminator is attached only now, so reset it (together with
        # the encoder and decoder) once it exists.
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module.

        The discriminator is skipped when it has not been attached yet, so
        this method is safe to call while the parent constructor is still
        running.
        """
        super().reset_parameters()
        discriminator = getattr(self, 'discriminator', None)
        if discriminator is not None:
            reset(discriminator)

    def reg_loss(self, z: Tensor) -> Tensor:
        r"""Computes the regularization loss of the encoder.

        Args:
            z (torch.Tensor): The latent space :math:`\mathbf{Z}`.
        """
        real = torch.sigmoid(self.discriminator(z))
        real_loss = -torch.log(real + EPS).mean()
        return real_loss

    def discriminator_loss(self, z: Tensor) -> Tensor:
        r"""Computes the loss of the discriminator.

        Standard-normal noise shaped like :obj:`z` is scored as the "real"
        sample and the detached :obj:`z` as the "fake" sample.

        Args:
            z (torch.Tensor): The latent space :math:`\mathbf{Z}`.
        """
        real = torch.sigmoid(self.discriminator(torch.randn_like(z)))
        # Detach so this loss does not back-propagate into the encoder.
        fake = torch.sigmoid(self.discriminator(z.detach()))
        real_loss = -torch.log(real + EPS).mean()
        fake_loss = -torch.log(1 - fake + EPS).mean()
        return real_loss + fake_loss

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs) -> LossOutput:
        r"""Computes the ARGA encoder loss with reconstruction and
        regularization components.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices (positive edges).
            **kwargs: Extra keyword arguments forwarded to :meth:`embed`.

        Returns:
            LossOutput containing total loss and individual components.
        """
        z = self.embed(x, edge_index, **kwargs)
        recon = self.recon_loss(z, edge_index)
        reg = self.reg_loss(z)

        return LossOutput(
            total=recon + reg,
            components={
                'recon': recon.item(),
                'reg': reg.item()
            }
        )

    def loss_batch(self, batch: Data) -> LossOutput:
        r"""Computes encoder loss for a mini-batch with seed node slicing.

        Only the first :obj:`batch.batch_size` embeddings are kept, and only
        edges whose two endpoints both fall inside that range are used as
        positive edges.

        Args:
            batch (Data): A mini-batch from the loader.

        Returns:
            LossOutput containing total loss and individual components.
        """
        z = self.embed(batch.x, batch.edge_index)
        z = z[:batch.batch_size]

        # Extract edges within the batch
        batch_mask = ((batch.edge_index[0] < batch.batch_size)
                      & (batch.edge_index[1] < batch.batch_size))
        batch_edge_index = batch.edge_index[:, batch_mask]

        recon = self.recon_loss(z, batch_edge_index)
        reg = self.reg_loss(z)

        return LossOutput(
            total=recon + reg,
            components={
                'recon': recon.item(),
                'reg': reg.item()
            }
        )

    def train_encoder(self, data: Data, optimizer: torch.optim.Optimizer,
                      epoch: int, verbose: bool = True) -> float:
        r"""Trains the encoder for one epoch.

        This is equivalent to :meth:`train_full` but provided for clarity in
        the two-phase ARGA training procedure.

        Args:
            data (Data): The input full graph data.
            optimizer (torch.optim.Optimizer): The optimizer for encoder
                parameters.
            epoch (int): Current epoch number.
            verbose (bool, optional): If :obj:`True`, prints training
                progress. (default: :obj:`True`)

        Returns:
            Loss value of the epoch.
        """
        # ``train_full`` is not defined in this class; presumably inherited
        # from ``TrainableModel``.
        return self.train_full(data, optimizer, epoch, verbose)

    def train_discriminator(self, data: Data,
                            optimizer: torch.optim.Optimizer, epoch: int,
                            verbose: bool = True) -> float:
        r"""Trains the discriminator for one epoch.

        Args:
            data (Data): The input full graph data. It is unpacked with
                ``**data`` into keyword arguments for :meth:`embed`.
            optimizer (torch.optim.Optimizer): The optimizer for
                discriminator parameters.
            epoch (int): Current epoch number.
            verbose (bool, optional): If :obj:`True`, prints training
                progress. (default: :obj:`True`)

        Returns:
            Discriminator loss value of the epoch.
        """
        self.train()
        optimizer.zero_grad()

        z = self.embed(**data)
        loss = self.discriminator_loss(z)

        loss.backward()
        optimizer.step()

        if verbose:
            print(f"Epoch: {epoch:02d} Discriminator Loss: {loss.item():.4f}")

        return float(loss.item())
class ARGVA(ARGA):
    r"""The Adversarially Regularized Variational Graph Auto-Encoder model
    from the `"Adversarially Regularized Graph Autoencoder for Graph
    Embedding" <https://arxiv.org/abs/1802.04407>`_ paper.

    .. note::
        ARGVA requires a two-phase training procedure (encoder +
        discriminator). Use :meth:`train_encoder` and
        :meth:`train_discriminator` separately, or implement a custom
        training loop.

    Args:
        encoder (torch.nn.Module): The encoder module. Its output is
            unpacked into two tensors that this class uses as :math:`\mu`
            and :math:`\log\sigma`.
        discriminator (torch.nn.Module): The discriminator module.
        decoder (torch.nn.Module, optional): The decoder module. If set to
            :obj:`None`, will default to the
            :class:`torch_geometric.nn.models.InnerProductDecoder`.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        encoder: Module,
        discriminator: Module,
        decoder: Optional[Module] = None,
    ):
        # Note: We bypass ARGA's __init__ and call GAE's __init__ directly
        GAE.__init__(self, encoder, decoder)
        self.discriminator = discriminator
        # Most recent encoder outputs; populated by :meth:`embed`.
        self.__mu__ = None
        self.__logstd__ = None
        # The discriminator is attached only now, so reset it (together with
        # the encoder and decoder) once it exists.
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module.

        The discriminator is skipped when it has not been attached yet, so
        this method is safe to call while the parent constructor is still
        running.
        """
        reset(self.encoder)
        reset(self.decoder)
        discriminator = getattr(self, 'discriminator', None)
        if discriminator is not None:
            reset(discriminator)

    def reparametrize(self, mu: Tensor, logstd: Tensor) -> Tensor:
        r"""Reparametrization trick for variational inference.

        In training mode returns :obj:`mu` plus standard-normal noise scaled
        by :obj:`exp(logstd)`; otherwise returns :obj:`mu` unchanged.
        """
        if self.training:
            return mu + torch.randn_like(logstd) * torch.exp(logstd)
        else:
            return mu

    def embed(self, *args, **kwargs) -> Tensor:
        r"""Computes node embeddings via the variational encoder.

        The two encoder outputs are stored on the instance as :math:`\mu`
        and :math:`\log\sigma` (the latter clamped from above at
        :obj:`MAX_LOGSTD`) and then combined by :meth:`reparametrize`.
        """
        self.__mu__, self.__logstd__ = self.encoder(
            *args, **filter_kwargs(self.encoder.forward, kwargs))
        self.__logstd__ = self.__logstd__.clamp(max=MAX_LOGSTD)
        z = self.reparametrize(self.__mu__, self.__logstd__)
        return z

    def kl_loss(
        self,
        mu: Optional[Tensor] = None,
        logstd: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Computes the KL loss, either for the passed arguments :obj:`mu`
        and :obj:`logstd`, or based on latent variables from last encoding.

        Args:
            mu (torch.Tensor, optional): The latent space for :math:`\mu`.
                If set to :obj:`None`, uses the last computation of
                :math:`\mu`. (default: :obj:`None`)
            logstd (torch.Tensor, optional): The latent space for
                :math:`\log\sigma`. If set to :obj:`None`, uses the last
                computation of :math:`\log\sigma`; a passed tensor is
                clamped from above at :obj:`MAX_LOGSTD`.
                (default: :obj:`None`)
        """
        mu = self.__mu__ if mu is None else mu
        logstd = self.__logstd__ if logstd is None else logstd.clamp(
            max=MAX_LOGSTD)
        return -0.5 * torch.mean(
            torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1))

    def loss(self, x: Tensor, edge_index: Tensor, **kwargs) -> LossOutput:
        r"""Computes the ARGVA encoder loss with reconstruction, KL
        divergence, and regularization components.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices (positive edges).
            **kwargs: Extra keyword arguments forwarded to :meth:`embed`.

        Returns:
            LossOutput containing total loss and individual components.
        """
        z = self.embed(x, edge_index, **kwargs)
        recon = self.recon_loss(z, edge_index)
        kl = self.kl_loss()
        reg = self.reg_loss(z)

        return LossOutput(
            total=recon + kl + reg,
            components={
                'recon': recon.item(),
                'kl': kl.item(),
                'reg': reg.item()
            }
        )

    def loss_batch(self, batch: Data) -> LossOutput:
        r"""Computes encoder loss for a mini-batch with seed node slicing.

        Embeddings, :math:`\mu` and :math:`\log\sigma` are truncated to the
        first :obj:`batch.batch_size` rows, and only edges with both
        endpoints inside that range are used as positive edges.

        Args:
            batch (Data): A mini-batch from the loader.

        Returns:
            LossOutput containing total loss and individual components.
        """
        z = self.embed(batch.x, batch.edge_index)
        z = z[:batch.batch_size]
        mu = self.__mu__[:batch.batch_size]
        logstd = self.__logstd__[:batch.batch_size]

        # Extract edges within the batch
        batch_mask = ((batch.edge_index[0] < batch.batch_size)
                      & (batch.edge_index[1] < batch.batch_size))
        batch_edge_index = batch.edge_index[:, batch_mask]

        recon = self.recon_loss(z, batch_edge_index)
        kl = self.kl_loss(mu, logstd)
        reg = self.reg_loss(z)

        return LossOutput(
            total=recon + kl + reg,
            components={
                'recon': recon.item(),
                'kl': kl.item(),
                'reg': reg.item()
            }
        )