Source code for pyagc.encoders.sgformer

# This code is adapted from the following source:
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/sgformer.py


from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn.attention import SGFormerAttention
from torch_geometric.nn.models.sgformer import GraphModule
from torch_geometric.utils import to_dense_batch


class SGModule(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_layers=2,
        num_heads=1,
        dropout=0.5,
    ):
        super().__init__()

        self.attns = torch.nn.ModuleList()
        self.fcs = torch.nn.ModuleList()
        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.LayerNorm(hidden_channels))
        for _ in range(num_layers):
            self.attns.append(
                SGFormerAttention(hidden_channels, num_heads, hidden_channels))
            self.bns.append(torch.nn.LayerNorm(hidden_channels))

        self.dropout = dropout
        self.activation = F.relu

    def reset_parameters(self):
        for attn in self.attns:
            attn.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()

    def forward(self, x: Tensor, batch: Optional[Tensor] = None):
        # If batch is provided, sort it as to_dense_batch requires sorted batch;
        # if batch is None, to_dense_batch treats all nodes as a single graph.
        if batch is not None:
            batch, indices = batch.sort(stable=True)
            rev_perm = torch.empty_like(indices)
            rev_perm[indices] = torch.arange(len(indices), device=indices.device)
            x = x[indices]

        x, mask = to_dense_batch(x, batch)
        layer_ = []

        x = self.fcs[0](x)
        x = self.bns[0](x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        layer_.append(x)

        for i, attn in enumerate(self.attns):
            x = attn(x, mask)
            x = (x + layer_[i]) / 2.
            x = self.bns[i + 1](x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            layer_.append(x)

        x_mask = x[mask]

        # Reverse the sorting only if reordering was applied
        if batch is not None:
            x_mask = x_mask[rev_perm]

        return x_mask


[docs]class SGFormer(torch.nn.Module): r"""The sgformer module from the `"SGFormer: Simplifying and Empowering Transformers for Large-Graph Representations" <https://arxiv.org/abs/2306.10759>`_ paper. SGFormer integrates a **global attention module** and a **GNN module** to jointly capture: - global all-pair node interactions (Transformer-style attention) - local structural information (GNN message passing) **1. Simplified Global Attention** Given input node features :math:`Z^{(0)} \in \mathbb{R}^{N \times d}`: .. math:: Q = f_Q(Z^{(0)}), \quad K = f_K(Z^{(0)}), \quad V = f_V(Z^{(0)}) Normalize: .. math:: \tilde{Q} = \frac{Q}{\|Q\|_F}, \quad \tilde{K} = \frac{K}{\|K\|_F} Define diagonal normalization: .. math:: D = \operatorname{diag}\left(1 + \frac{1}{N} \tilde{Q}(\tilde{K}^\top \mathbf{1}) \right) The attention output is: .. math:: Z = \beta D^{-1} \left( V + \frac{1}{N} \tilde{Q}(\tilde{K}^\top V) \right) + (1 - \beta) Z^{(0)} This formulation achieves **linear complexity :math:`O(N)`** compared to :math:`O(N^2)` in standard Transformers :contentReference[oaicite:0]{index=0}. **2. GNN-based Local Propagation** Structural information is incorporated via a GNN: .. math:: Z_{\text{gnn}} = \mathrm{GN}(Z^{(0)}, A) where :math:`A` is the adjacency matrix. **3. Aggregation Strategy** The global and local representations are combined as: **(a) Weighted sum (add):** .. math:: Z_{\text{out}} = (1 - \alpha) Z + \alpha Z_{\text{gnn}} **(b) Concatenation (cat):** .. math:: Z_{\text{out}} = [Z \, \| \, Z_{\text{gnn}}] **4. Output Layer** .. math:: \hat{Y} = f_O(Z_{\text{out}}) where :math:`f_O` is a linear projection. Args: in_channels (int): Input channels. hidden_channels (int): Hidden channels. out_channels (int): Output channels. trans_num_layers (int): The number of layers for all-pair attention. (default: :obj:`2`) trans_num_heads (int): The number of heads for attention. (default: :obj:`1`) trans_dropout (float): Global dropout rate. (default: :obj:`0.5`) gnn_num_layers (int): The number of layers for GNN. (default: :obj:`3`) gnn_dropout (float): GNN dropout rate. (default: :obj:`0.5`) graph_weight (float): The weight balance global and gnn module. (default: :obj:`0.5`) aggregate (str): Aggregate type. (default: :obj:`add`) """
[docs] def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, trans_num_layers: int = 2, trans_num_heads: int = 1, trans_dropout: float = 0.5, gnn_num_layers: int = 3, gnn_dropout: float = 0.5, graph_weight: float = 0.5, aggregate: str = 'add', ): super().__init__() self.trans_conv = SGModule( in_channels, hidden_channels, trans_num_layers, trans_num_heads, trans_dropout, ) self.graph_conv = GraphModule( in_channels, hidden_channels, gnn_num_layers, gnn_dropout, ) self.graph_weight = graph_weight self.aggregate = aggregate if aggregate == 'add': self.fc = torch.nn.Linear(hidden_channels, out_channels) elif aggregate == 'cat': self.fc = torch.nn.Linear(2 * hidden_channels, out_channels) else: raise ValueError(f'Invalid aggregate type:{aggregate}') self.params1 = list(self.trans_conv.parameters()) self.params2 = list(self.graph_conv.parameters()) self.params2.extend(list(self.fc.parameters())) self.out_channels = out_channels
[docs] def reset_parameters(self) -> None: self.trans_conv.reset_parameters() self.graph_conv.reset_parameters() self.fc.reset_parameters()
[docs] def forward( self, x: Tensor, edge_index: Tensor, batch: Optional[Tensor] = None, ) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The input node features. edge_index (torch.Tensor or SparseTensor): The edge indices. batch (torch.Tensor, optional): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each element to a specific example. """ x1 = self.trans_conv(x, batch) x2 = self.graph_conv(x, edge_index) if self.aggregate == 'add': x = self.graph_weight * x2 + (1 - self.graph_weight) * x1 else: x = torch.cat((x1, x2), dim=1) return self.fc(x)