Source code for pyagc.encoders.tabencoder

import inspect
from typing import Any, Dict, List, Optional

import torch
import torch_frame
from torch import Tensor
from torch.nn import Module
from torch_frame import stype
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.typing import Adj, OptTensor


class TabularEncoder(Module):
    r"""Tabular encoder using PyTorch Frame.

    It maps a single TensorFrame into embeddings.

    Args:
        channels (int): Output embedding dimension.
        col_names_dict (Dict[torch_frame.stype, List[str]]): A mapping from
            stype → column names.
        col_stats (Dict[str, Dict[StatType, Any]]): Column statistics
            computed from the training set only.
        torch_frame_model_cls (defaults to ResNet): TorchFrame encoder class
            to use.
        torch_frame_model_kwargs (Dict[str, Any], optional): Keyword
            arguments for :class:`torch_frame_model_cls` class. When
            :obj:`None`, ``{"channels": 128, "num_layers": 2}`` is used,
            which is specific to :class:`torch_frame.nn.ResNet`; expect to
            change it for a different :class:`torch_frame_model_cls`.
        default_stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any], optional):
            A dictionary mapping from :obj:`torch_frame.stype` object into a
            tuple specifying :class:`torch_frame.nn.StypeEncoder` class and
            its keyword arguments :obj:`kwargs`. When :obj:`None`, a default
            mapping covering the categorical, numerical, multicategorical,
            embedding and timestamp stypes is used.
    """

    def __init__(
        self,
        channels: int,
        col_names_dict: Dict[stype, List[str]],
        col_stats: Dict[str, Dict[StatType, Any]],
        torch_frame_model_cls=ResNet,
        torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
        default_stype_encoder_cls_kwargs: Optional[Dict[stype, Any]] = None,
    ):
        super().__init__()

        # Defaults are built per call instead of living in the signature, so
        # no dict object is shared between instances (mutable-default hazard).
        if torch_frame_model_kwargs is None:
            torch_frame_model_kwargs = {
                "channels": 128,
                "num_layers": 2,
            }
        if default_stype_encoder_cls_kwargs is None:
            default_stype_encoder_cls_kwargs = {
                torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}),
                torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}),
                torch_frame.multicategorical: (
                    torch_frame.nn.MultiCategoricalEmbeddingEncoder,
                    {},
                ),
                torch_frame.embedding: (
                    torch_frame.nn.LinearEmbeddingEncoder,
                    {},
                ),
                torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
            }

        # Build stype → Encoder module dict, instantiating an encoder only
        # for the stypes that are actually present in this table.
        stype_encoder_dict = {
            st: encoder_cls(**encoder_kwargs)
            for st, (encoder_cls, encoder_kwargs)
            in default_stype_encoder_cls_kwargs.items()
            if st in col_names_dict
        }

        # The actual TorchFrame model (e.g., ResNet).
        self.encoder = torch_frame_model_cls(
            **torch_frame_model_kwargs,
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict=stype_encoder_dict,
        )

    def reset_parameters(self):
        """Reset the parameters of the wrapped TorchFrame model."""
        self.encoder.reset_parameters()

    def forward(self, tf: torch_frame.TensorFrame) -> Tensor:
        """Encode a TensorFrame with the wrapped TorchFrame model.

        Args:
            tf: TensorFrame.

        Returns:
            Tensor of shape [num_samples, channels].
        """
        return self.encoder(tf)
class TabularGraphEncoder(Module):
    r"""Two-stage encoder for tabular graphs.

    The forward pass chains two sub-modules:

    1. a tabular encoder that turns node tabular attributes (a
       :class:`torch_frame.TensorFrame`) into node embeddings, and
    2. a graph encoder (a PyG GNN model) that consumes those embeddings
       together with the graph connectivity.

    Use it when every node carries a row-like tabular feature
    representation and the graph structure should be exploited afterwards.

    Args:
        tabular_encoder (torch.nn.Module): A tabular encoder using PyTorch
            Frame. It maps a single TensorFrame into embeddings.
        graph_encoder (torch.nn.Module): A graph encoder that consumes node
            embeddings and graph connectivity. Typical examples are:
            :class:`torch_geometric.nn.models.GCN`,
            :class:`torch_geometric.nn.models.GraphSAGE`, etc.
    """

    def __init__(
        self,
        tabular_encoder: Module,
        graph_encoder: Module,
    ):
        super().__init__()
        self.tabular_encoder = tabular_encoder
        self.graph_encoder = graph_encoder

    def reset_parameters(self):
        """Reset parameters of both the tabular encoder and the graph encoder.

        A sub-module without a ``reset_parameters`` attribute is skipped.
        """
        # Tabular encoder first, then graph encoder.
        for submodule in (self.tabular_encoder, self.graph_encoder):
            if hasattr(submodule, "reset_parameters"):
                submodule.reset_parameters()

    def encode_tabular(self, tf: torch_frame.TensorFrame) -> Tensor:
        r"""Encode node tabular attributes into dense node embeddings.

        Args:
            tf (torch_frame.TensorFrame): Node features in TensorFrame
                format.

        Returns:
            Tensor: Node embeddings of shape [num_nodes, channels].
        """
        node_embeddings = self.tabular_encoder(tf)
        return node_embeddings

    def _filter_supported_kwargs(
        self,
        module: Module,
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        r"""Keep only the keyword arguments that ``module.forward`` accepts.

        If the forward signature declares ``**kwargs``, the given dictionary
        is returned as is. Otherwise a new dictionary is returned holding
        only the entries whose key is a declared parameter name.
        """
        forward_params = inspect.signature(module.forward).parameters

        # A catch-all ``**kwargs`` parameter means nothing has to be dropped.
        for param in forward_params.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                return kwargs

        supported: Dict[str, Any] = {}
        for name, value in kwargs.items():
            if name in forward_params:
                supported[name] = value
        return supported

    def encode_graph(
        self,
        x: Tensor,
        edge_index: Adj,
        edge_weight: OptTensor = None,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        batch_size: Optional[int] = None,
        num_sampled_nodes_per_hop: Optional[List[int]] = None,
        num_sampled_edges_per_hop: Optional[List[int]] = None,
    ) -> Tensor:
        r"""Apply the graph encoder on node embeddings.

        All arguments are gathered by name and then filtered against the
        graph encoder's forward signature, so only the arguments it supports
        are passed through.
        """
        candidates = dict(
            x=x,
            edge_index=edge_index,
            edge_weight=edge_weight,
            edge_attr=edge_attr,
            batch=batch,
            batch_size=batch_size,
            num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,
            num_sampled_edges_per_hop=num_sampled_edges_per_hop,
        )
        accepted = self._filter_supported_kwargs(self.graph_encoder, candidates)
        return self.graph_encoder(**accepted)

    def forward(
        self,
        x: torch_frame.TensorFrame,
        edge_index: Adj,
        edge_weight: OptTensor = None,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        batch_size: Optional[int] = None,
        num_sampled_nodes_per_hop: Optional[List[int]] = None,
        num_sampled_edges_per_hop: Optional[List[int]] = None,
    ) -> Tensor:
        r"""Run the full pipeline.

        Tabular node attributes -> tabular embeddings -> graph encoder
        output.
        """
        node_embeddings = self.encode_tabular(x)
        return self.encode_graph(
            x=node_embeddings,
            edge_index=edge_index,
            edge_weight=edge_weight,
            edge_attr=edge_attr,
            batch=batch,
            batch_size=batch_size,
            num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,
            num_sampled_edges_per_hop=num_sampled_edges_per_hop,
        )

    def __repr__(self) -> str:
        # Only the class names of the two sub-modules are shown.
        own_name = self.__class__.__name__
        tabular_name = self.tabular_encoder.__class__.__name__
        graph_name = self.graph_encoder.__class__.__name__
        return (
            f"{own_name}("
            f"tabular_encoder={tabular_name}, "
            f"graph_encoder={graph_name})"
        )