pyagc.encoders.create_tuned_gnn
- create_tuned_gnn(gnn_type: str, in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_last: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, residual: bool = False, pre_linear: bool = False, jk: Optional[str] = None, **kwargs) → TunedGNN
Factory function to create tuned GNN models with recommended defaults.
This function provides an easy way to create tuned GNN models with hyperparameters optimized based on empirical findings from the “Classic GNNs are Strong Baselines: Reassessing GNNs for Node Classification” paper (Luo et al., NeurIPS 2024).
The function automatically filters out incompatible parameters for each GNN type by inspecting the model’s signature, so you can safely pass all parameters without worrying about compatibility.
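The signature-based filtering described above can be sketched with Python's standard `inspect.signature`. This is a minimal illustration, not pyagc's implementation: `filter_kwargs`, `build_layer`, `ToyGCNLayer`, and `ToyGATLayer` are hypothetical names, and the toy classes only stand in for real layer constructors. The sketch also assumes one policy choice of its own: if a constructor accepts `**kwargs`, nothing is dropped, because a simple signature check cannot tell which extra keywords it supports.

```python
import inspect
from typing import Any, Callable, Dict

# Parameter kinds that can be passed by keyword (positional-only ones cannot).
_NAMED_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)


def filter_kwargs(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return only the entries of `kwargs` that `fn` can accept by keyword."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        # fn takes **kwargs itself, so this check cannot tell what is unsupported.
        return dict(kwargs)
    accepted = {name for name, p in params.items() if p.kind in _NAMED_KINDS}
    return {k: v for k, v in kwargs.items() if k in accepted}


class ToyGCNLayer:
    """Hypothetical stand-in for a GCN layer constructor."""

    def __init__(self, in_channels: int, out_channels: int, improved: bool = False) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved


class ToyGATLayer:
    """Hypothetical stand-in for a GAT layer constructor."""

    def __init__(self, in_channels: int, out_channels: int, heads: int = 1, concat: bool = True) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat


def build_layer(layer_cls: Callable[..., Any], **kwargs: Any) -> Any:
    """Construct `layer_cls`, silently dropping keyword arguments it does not declare."""
    return layer_cls(**filter_kwargs(layer_cls, kwargs))


# One shared set of arguments works for both layer types.
shared = dict(in_channels=128, out_channels=256, heads=4, improved=True)
gcn = build_layer(ToyGCNLayer, **shared)  # 'heads' is dropped
gat = build_layer(ToyGATLayer, **shared)  # 'improved' is dropped
```

Calling `inspect.signature` on a class inspects its `__init__` without `self`, which is why the classes can be passed directly.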
- Parameters:
gnn_type (str) – Type of GNN. Options: "gcn", "sage", "gat", "gatv2", "gin", "pna", "edgecnn".
in_channels (int) – Size of each input sample.
hidden_channels (int) – Size of each hidden sample.
num_layers (int) – Number of message passing layers. Recommendation: 2-6 for homophilous graphs, 6-15 for heterophilous.
out_channels (int, optional) – Output size. If not set, hidden_channels is used. (default: None)
dropout (float, optional) – Dropout probability. The paper's findings suggest that values in the 0.2-0.7 range work well. (default: 0.0)
act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")
act_first (bool, optional) – If set to True, the activation is applied before normalization. (default: False)
act_last (bool, optional) – If set to True, the activation function is applied to the final output. Useful for tasks that require non-linear final representations. (default: False)
act_kwargs (Dict[str, Any], optional) – Arguments passed to the activation function defined by act. (default: None)
norm (str or Callable, optional) – Normalization type. Options: "batch_norm", "layer_norm". The paper recommends BatchNorm for large graphs and LayerNorm for smaller graphs. (default: None)
norm_kwargs (Dict[str, Any], optional) – Arguments passed to the normalization function defined by norm. (default: None)
residual (bool, optional) – If set to True, residual connections are applied. Especially beneficial for heterophilous graphs and deeper networks. (default: False)
pre_linear (bool, optional) – If set to True, a linear transformation is applied before the first GNN layer. (default: False)
jk (str, optional) – Jumping Knowledge mode. Options: None, "last", "cat", "max", "lstm". According to the paper, JK is optional but can help in some cases. (default: None)
**kwargs – Additional GNN-specific arguments. These are automatically filtered based on the GNN type. Common options include:
  - heads (int): Number of attention heads (GAT/GATv2 only)
  - concat (bool): Concatenate attention heads (GAT/GATv2 only)
  - v2 (bool): Use the GATv2 variant (GAT only; set automatically for "gatv2")
  - add_self_loops (bool): Add self-loops to the adjacency matrix
  - normalize (bool): Apply normalization (GCN only)
  - improved (bool): Use the improved GCN formulation (GCN only)
  - cached (bool): Cache normalized edge weights (GCN only)
  - bias (bool): Add bias parameters
  - aggr (str): Aggregation scheme (e.g., "mean", "max", "add")
  - aggregators (List[str]): Aggregation functions (PNA only)
  - scalers (List[str]): Scaling functions (PNA only)
  - deg (Tensor): Degree histogram for normalization (PNA only)
  - edge_dim (int): Edge feature dimensionality (GAT/GATv2/EdgeCNN)
  - fill_value (float or str): Fill value for added self-loops
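The jk modes listed above differ in how per-layer node representations are combined into one output. The following is a plain-Python sketch of the "last", "cat", and "max" modes under the usual Jumping Knowledge definitions; it is not pyagc's code, `jumping_knowledge` is a hypothetical helper, and "lstm" (which learns attention over layers with an LSTM) is deliberately left out.

```python
from typing import List

# Node features for one layer: matrix[node][feature].
Matrix = List[List[float]]


def jumping_knowledge(xs: List[Matrix], mode: str) -> Matrix:
    """Combine per-layer node representations `xs[layer][node][feature]`."""
    if not xs:
        raise ValueError("need at least one layer output")
    if mode == "last":
        # Keep only the final layer's representation.
        return [list(row) for row in xs[-1]]
    if mode == "cat":
        # Concatenate each node's features across all layers.
        num_nodes = len(xs[0])
        return [[v for layer in xs for v in layer[n]] for n in range(num_nodes)]
    if mode == "max":
        # Element-wise maximum over layers, per node and per feature.
        return [[max(vals) for vals in zip(*rows)] for rows in zip(*xs)]
    raise NotImplementedError(f"mode {mode!r} is not covered by this sketch")


# Two layers, two nodes, two features per node.
xs = [
    [[1.0, 5.0], [0.0, 0.0]],   # layer 1
    [[3.0, 2.0], [1.0, -1.0]],  # layer 2
]
print(jumping_knowledge(xs, "cat"))  # [[1.0, 5.0, 3.0, 2.0], [0.0, 0.0, 1.0, -1.0]]
print(jumping_knowledge(xs, "max"))  # [[3.0, 5.0], [1.0, 0.0]]
```

With "cat", the combined width is num_layers × hidden_channels, so any projection to out_channels that follows has to accept that width; "last" and "max" keep the width at hidden_channels.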
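The act_first flag only changes the order of two operations after each message passing layer, but that order changes the result. The sketch below assumes an ordering like PyG's BasicGNN models (convolution, then activation and normalization in the order chosen by act_first, then dropout, which is omitted here); whether TunedGNN uses exactly this order is an assumption. `post_conv` and `center` are hypothetical helpers, and `center` (mean subtraction only) is a simplified stand-in for a real normalization layer.

```python
from typing import Callable, List

Vector = List[float]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def center(x: Vector) -> Vector:
    # Simplified stand-in for normalization: subtract the mean (no scaling, no learned parameters).
    mean = sum(x) / len(x)
    return [v - mean for v in x]


def post_conv(
    x: Vector,
    act: Callable[[Vector], Vector],
    norm: Callable[[Vector], Vector],
    act_first: bool,
) -> Vector:
    """Apply activation and normalization to a layer output in the order selected by act_first."""
    if act_first:
        return norm(act(x))
    return act(norm(x))


h = [-3.0, 1.0, 2.0]  # a toy output of one message passing layer
print(post_conv(h, relu, center, act_first=True))   # [-1.0, 0.0, 1.0]
print(post_conv(h, relu, center, act_first=False))  # [0.0, 1.0, 2.0]
```

With act_first=True the normalized output can contain negative values again, while with act_first=False the activation has the last word.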
- Returns:
The initialized tuned GNN model.
- Return type:
TunedGNN
Examples
>>> from pyagc.encoders import create_tuned_gnn

>>> # Create a tuned GCN for homophilous graphs
>>> model = create_tuned_gnn(
...     'gcn', in_channels=128, hidden_channels=256,
...     num_layers=3, out_channels=10, dropout=0.5,
...     norm='batch_norm'
... )

>>> # Create a tuned GCN for heterophilous graphs (deeper + residual)
>>> model = create_tuned_gnn(
...     'gcn', in_channels=128, hidden_channels=256,
...     num_layers=10, out_channels=10, dropout=0.5,
...     norm='batch_norm', residual=True, pre_linear=True
... )

>>> # Create a tuned GAT with multiple attention heads
>>> model = create_tuned_gnn(
...     'gat', in_channels=128, hidden_channels=256,
...     num_layers=3, out_channels=10, heads=4, concat=True,
...     dropout=0.6, norm='layer_norm'
... )

>>> # Create a tuned model with a custom activation
>>> model = create_tuned_gnn(
...     'sage', in_channels=128, hidden_channels=256,
...     num_layers=3, act='elu', act_first=True,
...     norm='batch_norm', residual=True
... )

>>> # Create a model with Jumping Knowledge
>>> model = create_tuned_gnn(
...     'gcn', in_channels=128, hidden_channels=256,
...     num_layers=4, out_channels=10, jk='cat',
...     norm='layer_norm'
... )

>>> # Pass all parameters; incompatible ones are automatically filtered
>>> model = create_tuned_gnn(
...     'gcn', in_channels=128, hidden_channels=256,
...     num_layers=3, heads=4  # 'heads' will be ignored for GCN
... )
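The examples above switch to a deeper, residual configuration for heterophilous graphs, and the num_layers recommendation distinguishes 2-6 layers for homophilous graphs from 6-15 for heterophilous ones. One common way to place a labeled graph on that spectrum is the edge homophily ratio: the fraction of edges whose endpoints share a label. The sketch below computes it in plain Python; `edge_homophily` and `suggested_layer_range` are hypothetical helpers, and the 0.5 cut-off is an assumption of this sketch, not a value from the paper or from pyagc.

```python
from typing import Sequence, Tuple


def edge_homophily(edges: Sequence[Tuple[int, int]], labels: Sequence[int]) -> float:
    """Fraction of edges (u, v) with labels[u] == labels[v]."""
    if not edges:
        raise ValueError("graph has no edges")
    same = sum(1 for u, v in edges if labels[u] == labels[v])
    return same / len(edges)


def suggested_layer_range(h: float, threshold: float = 0.5) -> Tuple[int, int]:
    """Map a homophily ratio to the documented num_layers ranges.

    The threshold is an assumption of this sketch, not a documented value.
    """
    return (2, 6) if h >= threshold else (6, 15)


labels = [0, 0, 1, 1]
ring = [(0, 1), (1, 2), (2, 3), (3, 0)]   # half of the edges join same-label nodes
cross = [(0, 2), (1, 3), (0, 3)]          # every edge joins different labels

print(edge_homophily(ring, labels), suggested_layer_range(edge_homophily(ring, labels)))
print(edge_homophily(cross, labels), suggested_layer_range(edge_homophily(cross, labels)))
```

The resulting range is only a starting point for num_layers; the residual and pre_linear options shown in the heterophilous example are the ones the parameter descriptions associate with deeper networks.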