Source code for pyagc.models.nafs

from typing import Optional, List, Union
import torch
from torch import Tensor
from torch_geometric.typing import Adj
from torch_geometric.utils import to_undirected, add_remaining_self_loops
import torch.nn.functional as F

from pyagc.models.base import BaseModel


class NAFS(BaseModel):
    r"""Node-Adaptive Feature Smoothing (NAFS) from the `"NAFS: A Simple yet
    Tough-to-beat Baseline for Graph Representation Learning"
    <https://arxiv.org/abs/2206.08583>`_ paper (Zhang et al., ICML 2022).

    NAFS is a training-free method that constructs node representations by
    adaptively smoothing node features over different propagation steps.
    It addresses two key limitations of GNNs: over-smoothing and scalability.

    The method consists of two operations:

    **(1) Node-Adaptive Feature Smoothing:** For each smoothing step
    :math:`k`, compute:

    .. math::
        X^{(k)} = \hat{A}^k X

    where :math:`\hat{A} = \tilde{D}^{r-1} \tilde{A} \tilde{D}^{-r}` is the
    normalized adjacency matrix with parameter :math:`r`. Then adaptively
    combine them using smoothing weights:

    .. math::
        \hat{X} = \sum_{k=0}^{K} W^{(k)} X^{(k)}

    where :math:`W^{(k)}` are diagonal weight matrices derived from the
    distance between the smoothed and the original features of each node:

    .. math::
        D_i(k) = \text{dist}([X^{(k)}]_i, X_i), \qquad
        W^{(k)}_{ii} = \frac{\exp(-D_i(k))}{\sum_{l=0}^{K} \exp(-D_i(l))}

    so that, for every node, the steps that stay closest to its original
    features receive the largest weights.

    **(2) Feature Ensemble:** Combine smoothed features from different
    :math:`r` values:

    .. math::
        Z = \bigoplus_{t=1}^{T} \hat{X}^{(t)}

    where :math:`\bigoplus` can be concatenation, mean, or max pooling.

    Args:
        K (int): Maximum number of smoothing steps. (default: :obj:`20`)
        r_list (List[float], optional): List of normalization parameters
            for feature ensemble.
            (default: :obj:`[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]`)
        ensemble (str, optional): Ensemble strategy, one of :obj:`'mean'`,
            :obj:`'max'`, or :obj:`'concat'`. (default: :obj:`'concat'`)
        distance (str, optional): Distance function for computing
            over-smoothing distance, one of :obj:`'cosine'` or
            :obj:`'euclidean'`. (default: :obj:`'cosine'`)

    Raises:
        ValueError: If :obj:`ensemble` or :obj:`distance` is not one of the
            supported values, or if :obj:`r_list` is empty.

    Example:
        >>> from pyagc.models import NAFS
        >>> from pyagc.data import get_dataset
        >>>
        >>> # Load data
        >>> x, edge_index, y = get_dataset('Cora', root='./data')
        >>>
        >>> # Create model
        >>> model = NAFS(K=20, ensemble='concat')
        >>>
        >>> # Get embeddings
        >>> embeddings = model.embed(x, edge_index)
        >>>
        >>> # Use for clustering
        >>> from pyagc.clusters import KMeansClusterHead
        >>> kmeans = KMeansClusterHead(n_clusters=7)
        >>> pred = kmeans.fit_predict(embeddings)
    """

    def __init__(
        self,
        K: int = 20,
        r_list: Optional[List[float]] = None,
        ensemble: str = 'concat',
        distance: str = 'cosine',
    ):
        super().__init__()

        if ensemble not in ('mean', 'max', 'concat'):
            raise ValueError(f"Invalid ensemble strategy: {ensemble}")
        if distance not in ('cosine', 'euclidean'):
            raise ValueError(f"Invalid distance function: {distance}")

        self.K = K
        # Copy the caller's list so that later external mutation cannot
        # silently change the model configuration.
        self.r_list = (
            list(r_list) if r_list is not None
            else [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
        )
        if not self.r_list:
            # An empty list would otherwise only fail later, inside
            # ``torch.cat`` / ``torch.stack`` in :meth:`embed`.
            raise ValueError("r_list must contain at least one value")
        self.ensemble = ensemble
        self.distance = distance

    def _normalize_adj(
        self,
        edge_index: Adj,
        num_nodes: int,
        r: float = 0.5,
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        r"""Normalize adjacency matrix with parameter r.

        .. math::
            \hat{A} = \tilde{D}^{r-1} \tilde{A} \tilde{D}^{-r}

        Args:
            edge_index: Edge indices; unpacked as a ``(row, col)`` pair.
            num_nodes: Number of nodes, i.e. the side length of the result.
            r: Normalization exponent.
            dtype: Floating-point dtype of the edge weights. Defaults to
                :obj:`torch.float` when :obj:`None`.

        Returns:
            Coalesced sparse COO tensor of shape
            :obj:`(num_nodes, num_nodes)`.
        """
        if dtype is None:
            dtype = torch.float

        # Add self-loops, then symmetrize the edge set.
        edge_index = add_remaining_self_loops(
            edge_index, num_nodes=num_nodes
        )[0]
        edge_index = to_undirected(edge_index)

        # Degree of every node, counted over the processed edge list.
        row, col = edge_index
        deg = torch.zeros(num_nodes, dtype=dtype, device=edge_index.device)
        deg.scatter_add_(
            0, row, torch.ones(row.size(0), dtype=dtype, device=row.device)
        )

        # D^{r-1} on the left, D^{-r} on the right; a zero degree raised to
        # a negative power yields inf, which is reset to zero.
        deg_inv_left = deg.pow(r - 1)
        deg_inv_left[torch.isinf(deg_inv_left)] = 0.0
        deg_inv_right = deg.pow(-r)
        deg_inv_right[torch.isinf(deg_inv_right)] = 0.0

        edge_weight = deg_inv_left[row] * deg_inv_right[col]
        return torch.sparse_coo_tensor(
            edge_index, edge_weight, size=(num_nodes, num_nodes)
        ).coalesce()

    def _compute_distance(self, x: Tensor, x_smoothed: Tensor) -> Tensor:
        r"""Compute over-smoothing distance between original and smoothed
        features.

        Args:
            x: Original node features.
            x_smoothed: Smoothed node features, same shape as :obj:`x`.

        Returns:
            One distance per row: ``1 - cosine similarity`` when
            :obj:`distance` is :obj:`'cosine'`, otherwise the L2 norm of
            the row difference.
        """
        if self.distance == 'cosine':
            x_norm = F.normalize(x, p=2, dim=1)
            x_smoothed_norm = F.normalize(x_smoothed, p=2, dim=1)
            similarity = (x_norm * x_smoothed_norm).sum(dim=1)
            # Convert to distance (higher means farther from original).
            return 1.0 - similarity
        # euclidean
        return torch.norm(x - x_smoothed, p=2, dim=1)

    def _compute_smoothing_weights(
        self,
        features_list: List[Tensor],
        x_original: Tensor,
    ) -> Tensor:
        r"""Compute node-adaptive smoothing weights using softmax
        normalization.

        Args:
            features_list: List of smoothed features at different steps
            x_original: Original node features

        Returns:
            Weight matrix of shape :obj:`(n_nodes, K+1)` whose rows sum
            to one.
        """
        # Stack per-step distances: (n_nodes, K+1)
        distances = torch.stack(
            [self._compute_distance(x_original, x_k) for x_k in features_list],
            dim=1,
        )
        # The distance is measured to the *original* features, so it has to
        # be negated: a softmax over the raw distance would give the most
        # weight to the steps that drifted furthest, i.e. it would favour
        # over-smoothed features. For the cosine distance this equals a
        # softmax over cosine similarity (softmax is shift-invariant).
        return F.softmax(-distances, dim=1)

    def _feature_smoothing(
        self,
        x: Tensor,
        edge_index: Adj,
        r: float = 0.5,
    ) -> Tensor:
        r"""Perform node-adaptive feature smoothing for a given r value.

        Args:
            x: Node features; the first dimension is taken as the number
                of nodes.
            edge_index: Edge indices.
            r: Normalization exponent forwarded to :meth:`_normalize_adj`.

        Returns:
            Weighted combination of the :obj:`K + 1` propagated feature
            matrices, same shape as :obj:`x`.
        """
        if not x.is_floating_point():
            # Propagation and normalization need floating-point features.
            x = x.to(torch.get_default_dtype())

        num_nodes = x.size(0)

        # Build the adjacency in the dtype of ``x``; ``torch.sparse.mm``
        # rejects operands of different dtypes (e.g. float64 features).
        adj_norm = self._normalize_adj(edge_index, num_nodes, r, dtype=x.dtype)

        # Smoothed features at every step, starting with X^(0) = X.
        features_list = [x]
        x_current = x
        for _ in range(self.K):
            # X^(k) = A * X^(k-1)
            x_current = torch.sparse.mm(adj_norm, x_current)
            features_list.append(x_current)

        weights = self._compute_smoothing_weights(features_list, x)

        # Accumulate step by step rather than stacking all K+1 matrices
        # into one (n_nodes, K+1, n_features) tensor.
        output = torch.zeros_like(x)
        for k, x_k in enumerate(features_list):
            output += weights[:, k:k + 1] * x_k
        return output

    @torch.no_grad()
    def embed(self, x: Tensor, edge_index: Adj, **kwargs) -> Tensor:
        r"""Generate node embeddings using NAFS.

        Args:
            x (Tensor): Node features of shape
                :obj:`(num_nodes, num_features)`.
            edge_index (Tensor): Edge indices.
            **kwargs: Accepted but not used by this method.

        Returns:
            Node embeddings of shape :obj:`(num_nodes, num_features * T)`
            if ensemble='concat', otherwise
            :obj:`(num_nodes, num_features)`. :obj:`T` is the number of
            entries in :obj:`r_list`.
        """
        # Feature smoothing with different r values
        smoothed_features = [
            self._feature_smoothing(x, edge_index, r) for r in self.r_list
        ]

        # Feature ensemble
        if self.ensemble == 'mean':
            return torch.stack(smoothed_features).mean(dim=0)
        if self.ensemble == 'max':
            return torch.stack(smoothed_features).max(dim=0)[0]
        # concat
        return torch.cat(smoothed_features, dim=1)

    def forward(self, x: Tensor, edge_index: Adj, **kwargs) -> Tensor:
        r"""Alias for :meth:`embed`."""
        return self.embed(x, edge_index, **kwargs)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'K={self.K}, '
                f'r_list={self.r_list}, '
                f'ensemble={self.ensemble})')