Source code for pyagc.encoders.tuned_gnn

# Adapted from:
# 1. https://github.com/LUOyk1999/tunedGNN/blob/main/medium_graph/model.py
# 2. https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/basic_gnn.py

import copy
import inspect
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Linear, ModuleList
from tqdm import tqdm

from torch_geometric.data import Data
from torch_geometric.loader import CachedLoader, NeighborLoader
from torch_geometric.nn.conv import (
    EdgeConv,
    GATConv,
    GATv2Conv,
    GCNConv,
    GINConv,
    MessagePassing,
    PNAConv,
    SAGEConv,
)
from torch_geometric.nn.models import MLP
from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils._trim_to_layer import TrimToLayer


class TunedGNN(torch.nn.Module):
    r"""An enhanced GNN model with tuned hyperparameters based on the
    `"Classic GNNs are Strong Baselines: Reassessing GNNs for Node
    Classification" <https://arxiv.org/abs/2406.08993>`_ paper (Luo et al.,
    NeurIPS 2024).

    This is an abstract base class: subclasses choose the message passing
    operator by implementing :meth:`init_conv` and declare which edge inputs
    the operator accepts via the :obj:`supports_edge_weight` and
    :obj:`supports_edge_attr` class flags.

    This implementation incorporates critical improvements identified in the
    paper:

    - Residual connections for deeper networks and heterophilous graphs
    - Pre-linear transformation option
    - Flexible normalization (LayerNorm/BatchNorm)
    - Optimized dropout strategies
    - Support for deeper architectures (up to 10-15 layers)

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1`
            to derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): Size of the output embeddings. If set
            and :obj:`jk` is :obj:`None`, the last message passing layer maps
            directly to :obj:`out_channels`; if :obj:`jk` is set, a final
            linear transformation produces the output. If :obj:`None`, the
            output size is :obj:`hidden_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function
            to use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_last (bool, optional): If set to :obj:`True`, applies the
            activation function to the final output. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function.
            Recommended: :obj:`"batch_norm"` for large graphs,
            :obj:`"layer_norm"` for smaller graphs. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        residual (bool, optional): If set to :obj:`True`, adds a linearly
            projected copy of each layer's input to the layer's output.
            Especially beneficial for heterophilous graphs.
            (default: :obj:`False`)
        pre_linear (bool, optional): If set to :obj:`True`, applies a linear
            transformation (followed by dropout) before the first GNN layer.
            (default: :obj:`False`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the
            model will additionally apply a final linear transformation to
            transform node embeddings to the expected output feature
            dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`,
            :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of the underlying
            :class:`torch_geometric.nn.conv.MessagePassing` layers.
    """
    supports_edge_weight: Final[bool]
    supports_edge_attr: Final[bool]
    supports_norm_batch: Final[bool]

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        num_layers: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        act: Union[str, Callable, None] = "relu",
        act_first: bool = False,
        act_last: bool = False,
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Union[str, Callable, None] = None,
        norm_kwargs: Optional[Dict[str, Any]] = None,
        residual: bool = False,
        pre_linear: bool = False,
        jk: Optional[str] = None,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.residual = residual
        self.pre_linear = pre_linear

        self.dropout = torch.nn.Dropout(p=dropout)
        self.act = activation_resolver(act, **(act_kwargs or {}))
        self.jk_mode = jk
        self.act_first = act_first
        self.act_last = act_last
        # Only string specifications are kept; `init_conv` implementations
        # forward them to nested MLPs.
        self.norm = norm if isinstance(norm, str) else None
        self.norm_kwargs = norm_kwargs

        if out_channels is not None:
            self.out_channels = out_channels
        else:
            self.out_channels = hidden_channels

        # The last conv maps straight to `out_channels` only when no Jumping
        # Knowledge head (with its own final linear layer) follows it.
        conv_to_out = out_channels is not None and jk is None
        last_channels = out_channels if conv_to_out else hidden_channels

        # Pre-linear transformation (optional).
        # NOTE(review): `Linear` here is `torch.nn.Linear`; lazy sizes
        # (`in_channels=-1`) presumably do not work together with
        # `pre_linear`/`residual` -- confirm before relying on it.
        if self.pre_linear:
            self.lin_in = Linear(in_channels, hidden_channels)
            conv_in_channels = hidden_channels
        else:
            conv_in_channels = in_channels

        # Message passing layers: all but the last one map to
        # `hidden_channels`.
        self.convs = ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(
                self.init_conv(conv_in_channels, hidden_channels, **kwargs))
            if isinstance(conv_in_channels, (tuple, list)):
                conv_in_channels = (hidden_channels, hidden_channels)
            else:
                conv_in_channels = hidden_channels

        if conv_to_out:
            # Must be set *before* building the last conv: `init_conv`
            # implementations may inspect it (e.g. to disable head
            # concatenation).
            self._is_conv_to_out = True
        self.convs.append(
            self.init_conv(conv_in_channels, last_channels, **kwargs))

        # Linear projections for the residual connections, one per conv.
        if self.residual:
            self.res_lins = ModuleList()
            res_in = hidden_channels if self.pre_linear else in_channels
            if not isinstance(res_in, int):
                # Tuple sizes: project from the source dimensionality.
                res_in = res_in[0]
            for _ in range(num_layers - 1):
                self.res_lins.append(Linear(res_in, hidden_channels))
                res_in = hidden_channels
            self.res_lins.append(Linear(res_in, last_channels))

        # Normalization layers, one per conv.
        self.norms = ModuleList()
        norm_layer = normalization_resolver(
            norm,
            hidden_channels,
            **(norm_kwargs or {}),
        )
        if norm_layer is None:
            norm_layer = torch.nn.Identity()

        # Some normalization layers take the `batch` vector in `forward`.
        self.supports_norm_batch = False
        if hasattr(norm_layer, 'forward'):
            norm_params = inspect.signature(norm_layer.forward).parameters
            self.supports_norm_batch = 'batch' in norm_params

        for _ in range(num_layers - 1):
            self.norms.append(copy.deepcopy(norm_layer))

        # The last layer is only normalized when a JK head follows it.
        if jk is not None:
            self.norms.append(copy.deepcopy(norm_layer))
        else:
            self.norms.append(torch.nn.Identity())

        # Jumping Knowledge aggregation and its final linear layer.
        if jk is not None and jk != 'last':
            self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)

        if jk is not None:
            if jk == 'cat':
                in_channels_jk = num_layers * hidden_channels
            else:
                in_channels_jk = hidden_channels
            self.lin = Linear(in_channels_jk, self.out_channels)

        # We define `trim_to_layer` functionality as a module such that we can
        # still use `to_hetero` on-top.
        self._trim = TrimToLayer()

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        r"""Builds a single message passing layer.

        Must be implemented by subclasses.

        Args:
            in_channels (int or tuple): Input size of the layer.
            out_channels (int): Output size of the layer.
            **kwargs (optional): Additional arguments for the layer.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if self.pre_linear:
            self.lin_in.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        if self.residual:
            for res_lin in self.res_lins:
                res_lin.reset_parameters()
        for norm in self.norms:
            # `torch.nn.Identity` placeholders have nothing to reset.
            if hasattr(norm, 'reset_parameters'):
                norm.reset_parameters()
        if hasattr(self, 'jk'):
            self.jk.reset_parameters()
        if hasattr(self, 'lin'):
            self.lin.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        edge_weight: OptTensor = None,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        batch_size: Optional[int] = None,
        num_sampled_nodes_per_hop: Optional[List[int]] = None,
        num_sampled_edges_per_hop: Optional[List[int]] = None,
    ) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_weight (torch.Tensor, optional): The edge weights (if
                supported by the underlying GNN layer). (default: :obj:`None`)
            edge_attr (torch.Tensor, optional): The edge features (if
                supported by the underlying GNN layer). (default: :obj:`None`)
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
                Only needs to be passed in case the underlying normalization
                layers require the :obj:`batch` information.
                (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given.
                Only needs to be passed in case the underlying normalization
                layers require the :obj:`batch` information.
                (default: :obj:`None`)
            num_sampled_nodes_per_hop (List[int], optional): The number of
                sampled nodes per hop.
                Useful in :class:`~torch_geometric.loader.NeighborLoader`
                scenarios to only operate on minimal-sized representations.
                (default: :obj:`None`)
            num_sampled_edges_per_hop (List[int], optional): The number of
                sampled edges per hop.
                Useful in :class:`~torch_geometric.loader.NeighborLoader`
                scenarios to only operate on minimal-sized representations.
                (default: :obj:`None`)

        Returns:
            torch.Tensor: The output node embeddings.

        Raises:
            NotImplementedError: If trimming is requested while both
                :obj:`edge_weight` and :obj:`edge_attr` are given.
        """
        if (num_sampled_nodes_per_hop is not None
                and isinstance(edge_weight, Tensor)
                and isinstance(edge_attr, Tensor)):
            raise NotImplementedError("'trim_to_layer' functionality does not "
                                      "yet support trimming of both "
                                      "'edge_weight' and 'edge_attr'")

        # Pre-linear transformation.
        if self.pre_linear:
            x = self.lin_in(x)
            x = self.dropout(x)

        xs: List[Tensor] = []
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            # Trim to layer for mini-batch training.
            if (not torch.jit.is_scripting()
                    and num_sampled_nodes_per_hop is not None):
                x, edge_index, value = self._trim(
                    i,
                    num_sampled_nodes_per_hop,
                    num_sampled_edges_per_hop,
                    x,
                    edge_index,
                    edge_weight if edge_weight is not None else edge_attr,
                )
                if edge_weight is not None:
                    edge_weight = value
                else:
                    edge_attr = value

            # Keep the (possibly trimmed) layer input for the residual branch.
            x_res = x

            # Only pass the edge inputs the operator declares support for.
            if self.supports_edge_weight and self.supports_edge_attr:
                x = conv(x, edge_index, edge_weight=edge_weight,
                         edge_attr=edge_attr)
            elif self.supports_edge_weight:
                x = conv(x, edge_index, edge_weight=edge_weight)
            elif self.supports_edge_attr:
                x = conv(x, edge_index, edge_attr=edge_attr)
            else:
                x = conv(x, edge_index)

            if self.residual:
                x = x + self.res_lins[i](x_res)

            # Normalization, activation and dropout are skipped for the last
            # layer unless a Jumping Knowledge head follows it.
            if i < self.num_layers - 1 or self.jk_mode is not None:
                if self.act is not None and self.act_first:
                    x = self.act(x)
                if self.supports_norm_batch:
                    x = norm(x, batch, batch_size)
                else:
                    x = norm(x)
                if self.act is not None and not self.act_first:
                    x = self.act(x)
                x = self.dropout(x)
                if hasattr(self, 'jk'):
                    xs.append(x)

        # Jumping Knowledge aggregation and final linear transformation.
        x = self.jk(xs) if hasattr(self, 'jk') else x
        x = self.lin(x) if hasattr(self, 'lin') else x

        # Apply activation to the final output if requested.
        if self.act is not None and self.act_last:
            x = self.act(x)

        return x

    @torch.no_grad()
    def inference_per_layer(
        self,
        layer: int,
        x: Tensor,
        edge_index: Adj,
        batch_size: int,
    ) -> Tensor:
        r"""Applies a single layer of the model, mirroring what
        :meth:`forward` does for that layer (without dropout).

        Args:
            layer (int): Index of the layer to apply.
            x (torch.Tensor): Features of all nodes in the sampled subgraph.
                The first :obj:`batch_size` rows are taken as the target
                nodes.
            edge_index (torch.Tensor or SparseTensor): The edge indices of
                the sampled subgraph.
            batch_size (int): Number of target nodes.

        Returns:
            torch.Tensor: The layer output for the first :obj:`batch_size`
            nodes.
        """
        is_last = layer == self.num_layers - 1

        out = self.convs[layer](x, edge_index)[:batch_size]
        if self.residual:
            # Only the target rows are needed, so project just those.
            out = out + self.res_lins[layer](x[:batch_size])

        # Same gating as in `forward`: the last layer is only normalized and
        # activated when a Jumping Knowledge head follows it.
        if not is_last or self.jk_mode is not None:
            if self.act is not None and self.act_first:
                out = self.act(out)
            out = self.norms[layer](out)
            if self.act is not None and not self.act_first:
                out = self.act(out)

        if is_last:
            if hasattr(self, 'lin'):
                out = self.lin(out)
            # `forward` applies `act_last` regardless of the JK mode, so it
            # must be applied here as well.
            if self.act is not None and self.act_last:
                out = self.act(out)

        return out

    @torch.no_grad()
    def inference(
        self,
        loader: NeighborLoader,
        device: Optional[Union[str, torch.device]] = None,
        embedding_device: Union[str, torch.device] = 'cpu',
        progress_bar: bool = False,
        cache: bool = False,
    ) -> Tensor:
        r"""Performs layer-wise inference on large-graphs using a
        :class:`~torch_geometric.loader.NeighborLoader`, where
        :class:`~torch_geometric.loader.NeighborLoader` should sample the
        full neighborhood for only one layer.
        This is an efficient way to compute the output embeddings for all
        nodes in the graph.
        Only applicable in case :obj:`jk=None` or `jk='last'`.

        Args:
            loader (torch_geometric.loader.NeighborLoader): A neighbor loader
                object that generates full 1-hop subgraphs, *i.e.*,
                :obj:`loader.num_neighbors = [-1]`.
            device (torch.device, optional): The device to run the GNN on.
                (default: :obj:`None`)
            embedding_device (torch.device, optional): The device to store
                intermediate embeddings on. If intermediate embeddings fit on
                GPU, this option helps to avoid unnecessary device transfers.
                (default: :obj:`"cpu"`)
            progress_bar (bool, optional): If set to :obj:`True`, will print a
                progress bar during computation. (default: :obj:`False`)
            cache (bool, optional): If set to :obj:`True`, caches intermediate
                sampler outputs for usage in later epochs. This will avoid
                repeated sampling to accelerate inference.
                (default: :obj:`False`)

        Returns:
            torch.Tensor: The output embeddings of all nodes, stored on
            :obj:`embedding_device`.
        """
        assert self.jk_mode is None or self.jk_mode == 'last'
        assert isinstance(loader, NeighborLoader)
        assert len(loader.dataset) == loader.data.num_nodes
        assert len(loader.node_sampler.num_neighbors) == 1
        assert not self.training
        # assert not loader.shuffle  # TODO (matthias) does not work :(

        pbar = None
        if progress_bar:
            pbar = tqdm(total=len(self.convs) * len(loader))
            pbar.set_description('Inference')

        x_all = loader.data.x.to(embedding_device)

        # Pre-linear transformation. The parameters live on `device`, so the
        # features are moved there for the computation and the result is
        # moved back to `embedding_device`.
        if self.pre_linear:
            x_all = self.lin_in(x_all.to(device)).to(embedding_device)

        if cache:

            # Only cache necessary attributes:
            def transform(data: Data) -> Data:
                kwargs = dict(n_id=data.n_id, batch_size=data.batch_size)
                if hasattr(data, 'adj_t'):
                    kwargs['adj_t'] = data.adj_t
                else:
                    kwargs['edge_index'] = data.edge_index

                return Data.from_dict(kwargs)

            loader = CachedLoader(loader, device=device, transform=transform)

        # Compute the embeddings of *all* nodes layer by layer; each pass
        # over the loader consumes the previous layer's full output.
        for i in range(self.num_layers):
            xs: List[Tensor] = []
            for batch in loader:
                x = x_all[batch.n_id].to(device)
                batch_size = batch.batch_size
                if hasattr(batch, 'adj_t'):
                    edge_index = batch.adj_t.to(device)
                else:
                    edge_index = batch.edge_index.to(device)

                x = self.inference_per_layer(i, x, edge_index, batch_size)
                xs.append(x.to(embedding_device))

                if pbar is not None:
                    pbar.update(1)

            x_all = torch.cat(xs, dim=0)

        if pbar is not None:
            pbar.close()

        return x_all

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_layers={self.num_layers}, '
                f'residual={self.residual})')
class TunedGCN(TunedGNN):
    r"""Tuned Graph Convolutional Network based on the `"Classic GNNs are
    Strong Baselines: Reassessing GNNs for Node Classification"
    <https://arxiv.org/abs/2406.08993>`_ paper (Luo et al., NeurIPS 2024).

    Every message passing layer is a
    :class:`torch_geometric.nn.conv.GCNConv`; edge weights are forwarded to
    the layers, edge attributes are not.

    Key improvements over standard GCN:

    - Residual connections (especially beneficial for deep networks)
    - Flexible normalization (BatchNorm/LayerNorm)
    - Optimized dropout strategies
    - Optional pre-linear transformation

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): Size of the output embeddings. If
            :obj:`None`, the output size is :obj:`hidden_channels`.
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function
            to use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_last (bool, optional): If set to :obj:`True`, applies the
            activation function to the final output. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function.
            Recommended: :obj:`"batch_norm"` for large graphs,
            :obj:`"layer_norm"` for smaller graphs. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        residual (bool, optional): If set to :obj:`True`, applies residual
            connections. Especially beneficial for heterophilous graphs.
            (default: :obj:`False`)
        pre_linear (bool, optional): Apply a linear transformation before
            the first GNN layer. (default: :obj:`False`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the
            model will additionally apply a final linear transformation to
            transform node embeddings to the expected output feature
            dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`,
            :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GCNConv`.
    """
    supports_edge_weight: Final[bool] = True
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        r"""Builds one :class:`~torch_geometric.nn.conv.GCNConv` layer
        mapping :obj:`in_channels` to :obj:`out_channels`.
        """
        layer = GCNConv(in_channels, out_channels, **kwargs)
        return layer
class TunedGraphSAGE(TunedGNN):
    r"""Tuned GraphSAGE Network based on the `"Classic GNNs are Strong
    Baselines: Reassessing GNNs for Node Classification"
    <https://arxiv.org/abs/2406.08993>`_ paper (Luo et al., NeurIPS 2024).

    Every message passing layer is a
    :class:`torch_geometric.nn.conv.SAGEConv`; neither edge weights nor edge
    attributes are forwarded to the layers.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1`
            to derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): Size of the output embeddings. If
            :obj:`None`, the output size is :obj:`hidden_channels`.
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function
            to use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_last (bool, optional): If set to :obj:`True`, applies the
            activation function to the final output. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function.
            Recommended: :obj:`"batch_norm"` for large graphs,
            :obj:`"layer_norm"` for smaller graphs. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        residual (bool, optional): If set to :obj:`True`, applies residual
            connections. Especially beneficial for heterophilous graphs.
            (default: :obj:`False`)
        pre_linear (bool, optional): Apply a linear transformation before
            the first GNN layer. (default: :obj:`False`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the
            model will additionally apply a final linear transformation to
            transform node embeddings to the expected output feature
            dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`,
            :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.SAGEConv`.
    """
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        r"""Builds one :class:`~torch_geometric.nn.conv.SAGEConv` layer
        mapping :obj:`in_channels` to :obj:`out_channels`.
        """
        layer = SAGEConv(in_channels, out_channels, **kwargs)
        return layer
class TunedGIN(TunedGNN):
    r"""Tuned Graph Isomorphism Network based on the `"Classic GNNs are
    Strong Baselines: Reassessing GNNs for Node Classification"
    <https://arxiv.org/abs/2406.08993>`_ paper (Luo et al., NeurIPS 2024).

    Every message passing layer is a
    :class:`torch_geometric.nn.conv.GINConv` wrapping a two-layer
    :class:`~torch_geometric.nn.models.MLP`; neither edge weights nor edge
    attributes are forwarded to the layers.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): Size of the output embeddings. If
            :obj:`None`, the output size is :obj:`hidden_channels`.
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function
            to use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_last (bool, optional): If set to :obj:`True`, applies the
            activation function to the final output. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function.
            Recommended: :obj:`"batch_norm"` for large graphs,
            :obj:`"layer_norm"` for smaller graphs. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        residual (bool, optional): If set to :obj:`True`, applies residual
            connections. Especially beneficial for heterophilous graphs.
            (default: :obj:`False`)
        pre_linear (bool, optional): Apply a linear transformation before
            the first GNN layer. (default: :obj:`False`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the
            model will additionally apply a final linear transformation to
            transform node embeddings to the expected output feature
            dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`,
            :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GINConv`.
    """
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        r"""Builds one :class:`~torch_geometric.nn.conv.GINConv` layer whose
        update network is an MLP with channel sizes
        :obj:`[in_channels, out_channels, out_channels]`.

        The MLP reuses the model's activation, :obj:`act_first` setting and
        (string-specified) normalization.
        """
        channel_sizes = [in_channels, out_channels, out_channels]
        update_net = MLP(
            channel_sizes,
            act=self.act,
            act_first=self.act_first,
            norm=self.norm,
            norm_kwargs=self.norm_kwargs,
        )
        return GINConv(update_net, **kwargs)
class TunedGAT(TunedGNN):
    r"""Tuned Graph Attention Network based on the `"Classic GNNs are Strong
    Baselines: Reassessing GNNs for Node Classification"
    <https://arxiv.org/abs/2406.08993>`_ paper (Luo et al., NeurIPS 2024).

    Every message passing layer is a
    :class:`torch_geometric.nn.conv.GATConv` (or
    :class:`~torch_geometric.nn.conv.GATv2Conv` if :obj:`v2=True`); edge
    attributes are forwarded to the layers, edge weights are not.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1`
            to derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): Size of the output embeddings. If
            :obj:`None`, the output size is :obj:`hidden_channels`.
            (default: :obj:`None`)
        v2 (bool, optional): If set to :obj:`True`, will make use of
            :class:`~torch_geometric.nn.conv.GATv2Conv` rather than
            :class:`~torch_geometric.nn.conv.GATConv`.
            (default: :obj:`False`)
        heads (int, optional): Number of attention heads.
            (default: :obj:`1`)
        concat (bool, optional): Concatenate attention heads. When enabled,
            each layer's output size must be divisible by :obj:`heads`.
            (default: :obj:`True`)
        dropout (float, optional): Dropout probability. The same value is
            also passed to the attention layers as their :obj:`dropout`
            argument. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function
            to use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_last (bool, optional): If set to :obj:`True`, applies the
            activation function to the final output. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function.
            Recommended: :obj:`"batch_norm"` for large graphs,
            :obj:`"layer_norm"` for smaller graphs. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        residual (bool, optional): If set to :obj:`True`, applies residual
            connections. Especially beneficial for heterophilous graphs.
            (default: :obj:`False`)
        pre_linear (bool, optional): Apply a linear transformation before
            the first GNN layer. (default: :obj:`False`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the
            model will additionally apply a final linear transformation to
            transform node embeddings to the expected output feature
            dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`,
            :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GATConv` or
            :class:`torch_geometric.nn.conv.GATv2Conv`.
    """
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = True
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        r"""Builds one attention layer whose total output size is
        :obj:`out_channels`.

        The options :obj:`v2`, :obj:`heads` and :obj:`concat` are taken out
        of :obj:`kwargs`; everything else is forwarded to the layer.

        Raises:
            ValueError: If heads are concatenated and :obj:`out_channels` is
                not divisible by :obj:`heads`.
        """
        use_v2 = kwargs.pop('v2', False)
        num_heads = kwargs.pop('heads', 1)
        concat_heads = kwargs.pop('concat', True)

        # Do not use concatenation in case the `GATConv` layer maps directly
        # to the desired output channels (out_channels != None and
        # jk == None):
        if getattr(self, '_is_conv_to_out', False):
            concat_heads = False

        if concat_heads:
            per_head, remainder = divmod(out_channels, num_heads)
            if remainder != 0:
                raise ValueError(f"Ensure that the number of output channels of "
                                 f"'GATConv' (got '{out_channels}') is divisible "
                                 f"by the number of heads (got '{num_heads}')")
            # Concatenated heads jointly produce `out_channels` features.
            out_channels = per_head

        conv_cls = GATv2Conv if use_v2 else GATConv
        return conv_cls(in_channels, out_channels, heads=num_heads,
                        concat=concat_heads, dropout=self.dropout.p, **kwargs)
class TunedPNA(TunedGNN):
    r"""Tuned Principal Neighbourhood Aggregation Network based on the
    `"Classic GNNs are Strong Baselines: Reassessing GNNs for Node
    Classification" <https://arxiv.org/abs/2406.08993>`_ paper (Luo et al.,
    NeurIPS 2024).

    Every message passing layer is a
    :class:`torch_geometric.nn.conv.PNAConv`; edge attributes are forwarded
    to the layers, edge weights are not.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): Size of the output embeddings. If
            :obj:`None`, the output size is :obj:`hidden_channels`.
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function
            to use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_last (bool, optional): If set to :obj:`True`, applies the
            activation function to the final output. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function.
            Recommended: :obj:`"batch_norm"` for large graphs,
            :obj:`"layer_norm"` for smaller graphs. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        residual (bool, optional): If set to :obj:`True`, applies residual
            connections. Especially beneficial for heterophilous graphs.
            (default: :obj:`False`)
        pre_linear (bool, optional): Apply a linear transformation before
            the first GNN layer. (default: :obj:`False`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the
            model will additionally apply a final linear transformation to
            transform node embeddings to the expected output feature
            dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`,
            :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.PNAConv`.
    """
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = True
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        r"""Builds one :class:`~torch_geometric.nn.conv.PNAConv` layer
        mapping :obj:`in_channels` to :obj:`out_channels`.
        """
        layer = PNAConv(in_channels, out_channels, **kwargs)
        return layer
class TunedEdgeCNN(TunedGNN):
    r"""Tuned EdgeCNN (Dynamic Graph CNN) based on the `"Classic GNNs are
    Strong Baselines: Reassessing GNNs for Node Classification"
    <https://arxiv.org/abs/2406.08993>`_ paper (Luo et al., NeurIPS 2024).

    Every message passing layer is a
    :class:`torch_geometric.nn.conv.EdgeConv` wrapping a two-layer
    :class:`~torch_geometric.nn.models.MLP`; neither edge weights nor edge
    attributes are forwarded to the layers.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): Size of the output embeddings. If
            :obj:`None`, the output size is :obj:`hidden_channels`.
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function
            to use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_last (bool, optional): If set to :obj:`True`, applies the
            activation function to the final output. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function.
            Recommended: :obj:`"batch_norm"` for large graphs,
            :obj:`"layer_norm"` for smaller graphs. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        residual (bool, optional): If set to :obj:`True`, applies residual
            connections. Especially beneficial for heterophilous graphs.
            (default: :obj:`False`)
        pre_linear (bool, optional): Apply a linear transformation before
            the first GNN layer. (default: :obj:`False`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the
            model will additionally apply a final linear transformation to
            transform node embeddings to the expected output feature
            dimensionality. (:obj:`None`, :obj:`"last"`, :obj:`"cat"`,
            :obj:`"max"`, :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.EdgeConv`.
    """
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        r"""Builds one :class:`~torch_geometric.nn.conv.EdgeConv` layer whose
        edge network is an MLP with channel sizes
        :obj:`[2 * in_channels, out_channels, out_channels]`.

        The MLP reuses the model's activation, :obj:`act_first` setting and
        (string-specified) normalization.
        """
        channel_sizes = [2 * in_channels, out_channels, out_channels]
        edge_net = MLP(
            channel_sizes,
            act=self.act,
            act_first=self.act_first,
            norm=self.norm,
            norm_kwargs=self.norm_kwargs,
        )
        return EdgeConv(edge_net, **kwargs)
# Factory function for convenient model creation
def create_tuned_gnn(
    gnn_type: str,
    in_channels: int,
    hidden_channels: int,
    num_layers: int,
    out_channels: Optional[int] = None,
    dropout: float = 0.0,
    act: Union[str, Callable, None] = "relu",
    act_first: bool = False,
    act_last: bool = False,
    act_kwargs: Optional[Dict[str, Any]] = None,
    norm: Union[str, Callable, None] = None,
    norm_kwargs: Optional[Dict[str, Any]] = None,
    residual: bool = False,
    pre_linear: bool = False,
    jk: Optional[str] = None,
    **kwargs
) -> TunedGNN:
    r"""Factory function to create tuned GNN models with recommended defaults.

    This function provides an easy way to create tuned GNN models with
    hyperparameters optimized based on empirical findings from the
    `"Classic GNNs are Strong Baselines: Reassessing GNNs for Node
    Classification" <https://arxiv.org/abs/2406.08993>`_ paper
    (Luo et al., NeurIPS 2024).

    The function filters out parameters that the selected model class does
    not declare as keyword-bindable parameters of its ``__init__``
    signature, so all parameters can be passed without worrying about
    compatibility. Dropped parameters are reported through a
    :class:`UserWarning`.

    Args:
        gnn_type (str): Type of GNN (case-insensitive). Options:
            :obj:`"gcn"`, :obj:`"sage"` (alias :obj:`"graphsage"`),
            :obj:`"gat"`, :obj:`"gatv2"`, :obj:`"gin"`, :obj:`"pna"`,
            :obj:`"edgecnn"`.
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
            Recommendation: 2-6 for homophilous graphs, 6-15 for
            heterophilous.
        out_channels (int, optional): Output size. If not set, will use
            :obj:`hidden_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. Paper findings
            suggest 0.2-0.7 range works well. (default: :obj:`0.0`)
        act (str or Callable, optional): The non-linear activation function
            to use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_last (bool, optional): If set to :obj:`True`, applies
            activation function to the final output. Useful for tasks
            requiring non-linear final representations.
            (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): Normalization type. Options:
            :obj:`"batch_norm"`, :obj:`"layer_norm"`. Paper recommends
            BatchNorm for large graphs, LayerNorm for smaller graphs.
            (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        residual (bool, optional): If set to :obj:`True`, applies residual
            connections. Especially beneficial for heterophilous graphs
            and deeper networks. (default: :obj:`False`)
        pre_linear (bool, optional): If set to :obj:`True`, applies a
            linear transformation before the first GNN layer.
            (default: :obj:`False`)
        jk (str, optional): Jumping Knowledge mode. Options: :obj:`None`,
            :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`.
            Paper shows this is optional but can help in some cases.
            (default: :obj:`None`)
        **kwargs: Additional GNN-specific arguments. These will be
            automatically filtered based on the GNN type. Common options
            include:

            - :obj:`heads` (int): Number of attention heads
              (GAT/GATv2 only)
            - :obj:`concat` (bool): Concatenate attention heads
              (GAT/GATv2 only)
            - :obj:`v2` (bool): Use GATv2 variant (GAT only, auto-set
              for gatv2)
            - :obj:`add_self_loops` (bool): Add self-loops to adjacency
              matrix
            - :obj:`normalize` (bool): Apply normalization (GCN only)
            - :obj:`improved` (bool): Use improved GCN formulation
              (GCN only)
            - :obj:`cached` (bool): Cache normalized edge weights
              (GCN only)
            - :obj:`bias` (bool): Add bias parameters
            - :obj:`aggr` (str): Aggregation scheme (e.g., "mean", "max",
              "add")
            - :obj:`aggregators` (List[str]): Aggregation functions
              (PNA only)
            - :obj:`scalers` (List[str]): Scaling functions (PNA only)
            - :obj:`deg` (Tensor): Degree histogram for normalization
              (PNA only)
            - :obj:`edge_dim` (int): Edge feature dimensionality
              (GAT/GATv2/EdgeCNN)
            - :obj:`fill_value` (float or str): Value for self-loops

    Returns:
        TunedGNN: The initialized tuned GNN model.

    Raises:
        ValueError: If :obj:`gnn_type` is not one of the supported options.

    Warns:
        UserWarning: If one or more parameters are not accepted by the
            selected model class and were therefore dropped.

    Examples:
        >>> # Create a tuned GCN for homophilous graphs
        >>> model = create_tuned_gnn(
        ...     'gcn', in_channels=128, hidden_channels=256,
        ...     num_layers=3, out_channels=10, dropout=0.5,
        ...     norm='batch_norm'
        ... )

        >>> # Create a tuned GCN for heterophilous graphs (deeper + residual)
        >>> model = create_tuned_gnn(
        ...     'gcn', in_channels=128, hidden_channels=256,
        ...     num_layers=10, out_channels=10, dropout=0.5,
        ...     norm='batch_norm', residual=True, pre_linear=True
        ... )

        >>> # Create a tuned GAT with multiple attention heads
        >>> model = create_tuned_gnn(
        ...     'gat', in_channels=128, hidden_channels=256,
        ...     num_layers=3, out_channels=10, heads=4, concat=True,
        ...     dropout=0.6, norm='layer_norm'
        ... )

        >>> # Create a tuned model with custom activation
        >>> model = create_tuned_gnn(
        ...     'sage', in_channels=128, hidden_channels=256,
        ...     num_layers=3, act='elu', act_first=True,
        ...     norm='batch_norm', residual=True
        ... )

        >>> # Create a model with Jumping Knowledge
        >>> model = create_tuned_gnn(
        ...     'gcn', in_channels=128, hidden_channels=256,
        ...     num_layers=4, out_channels=10, jk='cat',
        ...     norm='layer_norm'
        ... )

        >>> # Pass all parameters - incompatible ones are automatically
        >>> # filtered
        >>> model = create_tuned_gnn(
        ...     'gcn', in_channels=128, hidden_channels=256,
        ...     num_layers=3, heads=4  # 'heads' will be ignored for GCN
        ... )
    """
    import warnings

    gnn_type = gnn_type.lower()

    model_map = {
        'gcn': TunedGCN,
        'sage': TunedGraphSAGE,
        'graphsage': TunedGraphSAGE,
        'gat': TunedGAT,
        'gatv2': TunedGAT,
        'gin': TunedGIN,
        'pna': TunedPNA,
        'edgecnn': TunedEdgeCNN,
    }

    if gnn_type not in model_map:
        raise ValueError(f"Unknown GNN type: {gnn_type}. "
                         f"Available options: {list(model_map.keys())}")

    model_class = model_map[gnn_type]

    # Collect the names that can really be bound by keyword. The names of
    # ``*args`` / ``**kwargs`` catch-alls and positional-only parameters are
    # not keyword names, so they must not let a same-named caller key
    # through the filter.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    init_signature = inspect.signature(model_class.__init__)
    valid_params = {
        name
        for name, param in init_signature.parameters.items()
        if name != 'self' and param.kind in keyword_kinds
    }
    # NOTE(review): the model constructors are not visible from this
    # function. If they receive conv-specific options (e.g. `heads`, `v2`)
    # through a `**kwargs` catch-all instead of named parameters, those
    # options are dropped by the filter below, including the `v2` flag set
    # for 'gatv2'. Confirm the model classes declare them explicitly.

    # Explicit arguments first, then caller extras.
    all_params = {
        'in_channels': in_channels,
        'hidden_channels': hidden_channels,
        'num_layers': num_layers,
        'out_channels': out_channels,
        'dropout': dropout,
        'act': act,
        'act_first': act_first,
        'act_last': act_last,
        'act_kwargs': act_kwargs,
        'norm': norm,
        'norm_kwargs': norm_kwargs,
        'residual': residual,
        'pre_linear': pre_linear,
        'jk': jk,
        **kwargs,
    }

    # GATv2-specific handling (set v2=True before filtering)
    if gnn_type == 'gatv2':
        all_params['v2'] = True

    # Keep only what the selected model class accepts.
    filtered_params = {
        key: value
        for key, value in all_params.items()
        if key in valid_params
    }

    # Report dropped parameters in sorted order so the message is
    # reproducible across runs (set ordering is not).
    filtered_out = sorted(all_params.keys() - filtered_params.keys())
    if filtered_out:
        warnings.warn(
            f"The following parameters are not applicable to {gnn_type.upper()} "
            f"and will be ignored: {filtered_out}",
            UserWarning,
            stacklevel=2,
        )

    return model_class(**filtered_params)
# Public API of this module: the base model, its concrete variants, and
# the factory function.
__all__ = [
    "TunedGNN",
    "TunedGCN",
    "TunedGraphSAGE",
    "TunedGAT",
    "TunedGIN",
    "TunedPNA",
    "TunedEdgeCNN",
    "create_tuned_gnn",
]