Source code for pyagc.data.datasets

import torch
import os
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid, CoraFull, Amazon, Coauthor, Flickr, Reddit2
from torch_geometric.utils import to_undirected, add_remaining_self_loops, subgraph
import psutil

from pyagc.data.graphland import GraphLandDataset


def get_available_ram_gb():
    """Get available RAM in GB."""
    return psutil.virtual_memory().total / (1024 ** 3)


[docs]def get_dataset(name: str, root: str, return_splits=False): r"""Loads a graph dataset by name and returns its features, edges, and labels. This function serves as a unified interface for loading a wide range of benchmark datasets used in graph learning, including both classical citation networks (e.g., Cora, PubMed) and large-scale Open Graph Benchmark (OGB) datasets (e.g., ogbn-arxiv, ogbn-products). It automatically normalizes node features, converts the graph to an undirected version. Optionally, it can also return predefined train/validation/test node splits for benchmarking purposes. Args: name (str): The name of the dataset to load. Supported options include: :obj:`['cora', 'citeseer', 'pubmed', 'corafull', 'photo', 'computers', 'cs', 'physics', 'flickr', 'reddit', 'reddit2', 'ogbn-arxiv', 'arxiv', 'ogbn-mag', 'mag', 'ogbn-products', 'products', 'ogbn-papers100M', 'papers100m', 'hm-categories', 'hm', 'pokec-regions', 'pokec', 'web-topics', 'webtopic']`. root (str): The root directory where the dataset should be stored. return_splits (bool, optional): If set to :obj:`True`, returns node-level split indices (train/valid/test) along with the features and edges. (default: :obj:`False`) Returns: (Tuple): Depending on :attr:`return_splits`: - If :obj:`False`, returns ``(x, edge_index, y)``: * ``x``: Node feature matrix :obj:`[num_nodes, num_features]` * ``edge_index``: Graph connectivity in COO format :obj:`[2, num_edges]` * ``y``: Node label vector :obj:`[num_nodes]` - If :obj:`True`, returns ``(x, edge_index, y, train_idx, valid_idx, test_idx)`` with additional index tensors for data splits. - For papers100M with return_splits=True, additionally returns: ``(x, edge_index, y, train_idx, valid_idx, test_idx, labeled_subgraph)`` where labeled_subgraph contains only edge_index and original_indices for structure metric computation. Raises: ValueError: If the provided dataset :attr:`name` is not recognized. """ # Normalize dataset name for case-insensitive matching. name = name.lower() # Special handling for papers100M dataset if name in ['ogbn-papers100M', 'papers100m']: return _load_papers100m(root, return_splits) if name in ['cora', 'citeseer', 'pubmed']: dataset = Planetoid(root=root, name=name, transform=T.NormalizeFeatures()) elif name in ['corafull']: dataset = CoraFull(f'{root}/{name}', transform=T.NormalizeFeatures()) elif name in ['photo', 'computers']: dataset = Amazon(root=root, name=name, transform=T.NormalizeFeatures()) elif name in ['cs', 'physics']: dataset = Coauthor(root=root, name=name, transform=T.NormalizeFeatures()) elif name in ['flickr']: dataset = Flickr(f'{root}/{name}', transform=T.NormalizeFeatures()) elif name in ['reddit', 'reddit2']: name = 'reddit' dataset = Reddit2(root=f'{root}/{name}', transform=T.NormalizeFeatures()) elif name in ['ogbn-arxiv', 'arxiv']: dataset = PygNodePropPredDataset(root=root, name='ogbn-arxiv') elif name in ['mag']: dataset = PygNodePropPredDataset(root=root, name='ogbn-mag') rel_data = dataset[0] # We are only interested in paper <-> paper relations. data = Data( x=rel_data.x_dict['paper'], edge_index=rel_data.edge_index_dict[('paper', 'cites', 'paper')], y=rel_data.y_dict['paper']) dataset._data = data elif name in ['ogbn-products', 'products']: dataset = PygNodePropPredDataset(root=root, name='ogbn-products') elif name in ['hm-categories', 'hm']: dataset = GraphLandDataset(root=root, name='hm-categories', split='TH') elif name in ['pokec-regions', 'pokec']: dataset = GraphLandDataset(root=root, name='pokec-regions', split='TH') elif name in ['web-topics', 'webtopic']: dataset = GraphLandDataset(root=root, name='web-topics', split='TH') else: raise ValueError(f'Unknown dataset: {name}') # Retrieve data object and apply structural normalization. data = dataset[0] data.edge_index = to_undirected(data.edge_index) data.y = data.y.squeeze() # Return with or without split indices. if return_splits: if isinstance(dataset, PygNodePropPredDataset): split_idx = dataset.get_idx_split() if name in ['mag']: # Handle heterogeneous dataset structure. train_idx, valid_idx, test_idx = ( split_idx['train']['paper'], split_idx['valid']['paper'], split_idx['test']['paper']) else: train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] else: # For standard datasets using boolean masks. train_idx = data.train_mask.nonzero(as_tuple=False).view(-1) valid_idx = data.val_mask.nonzero(as_tuple=False).view(-1) test_idx = data.test_mask.nonzero(as_tuple=False).view(-1) return data.x, data.edge_index, data.y, train_idx, valid_idx, test_idx else: return data.x, data.edge_index, data.y
def _load_papers100m(root: str, return_splits: bool): """Special handler for papers100M dataset with preprocessing and caching.""" # Define paths for preprocessed data processed_dir = os.path.join(root, 'ogbn_papers100M', 'processed_undirected') os.makedirs(processed_dir, exist_ok=True) processed_data_path = os.path.join(processed_dir, 'data.pt') processed_splits_path = os.path.join(processed_dir, 'splits.pt') processed_subgraph_path = os.path.join(processed_dir, 'labeled_subgraph.pt') # Check if preprocessed data exists if os.path.exists(processed_data_path): print(f"Loading preprocessed papers100M dataset from {processed_data_path}") cached_data = torch.load(processed_data_path) x = cached_data['x'] edge_index = cached_data['edge_index'] y = cached_data['y'] if return_splits: if os.path.exists(processed_splits_path): splits = torch.load(processed_splits_path) train_idx = splits['train'] valid_idx = splits['valid'] test_idx = splits['test'] else: print("Warning: Preprocessed splits not found, loading from dataset...") dataset = PygNodePropPredDataset(root=root, name='ogbn-papers100M') split_idx = dataset.get_idx_split() train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] # Save splits for future use torch.save({ 'train': train_idx, 'valid': valid_idx, 'test': test_idx }, processed_splits_path) # Load or create labeled subgraph (only structure, no features) if os.path.exists(processed_subgraph_path): print(f"Loading preprocessed labeled subgraph from {processed_subgraph_path}") labeled_subgraph = torch.load(processed_subgraph_path) else: print("Warning: Preprocessed labeled subgraph not found.") labeled_subgraph = None return x, edge_index, y, train_idx, valid_idx, test_idx, labeled_subgraph else: return x, edge_index, y # First time loading - check RAM and preprocess print("First time loading papers100M dataset...") available_ram = get_available_ram_gb() required_ram = 400 # GB print(f"Available RAM: {available_ram:.2f} GB") print(f"Estimated required RAM: {required_ram} GB") if available_ram < required_ram: raise MemoryError( f"Insufficient RAM for processing papers100M dataset. " f"Available: {available_ram:.2f} GB, Required: ~{required_ram} GB. " f"Please run this preprocessing step on a machine with sufficient memory." ) print("Loading original dataset...") dataset = PygNodePropPredDataset(root=root, name='ogbn-papers100M') data = dataset[0] print("Converting to undirected graph (this may take a while)...") edge_index_undirected = to_undirected(data.edge_index) print("Preparing data for saving...") y = data.y.squeeze() # Save preprocessed data print(f"Saving preprocessed data to {processed_data_path}") torch.save({ 'x': data.x, 'edge_index': edge_index_undirected, 'y': y }, processed_data_path) # Save splits and create labeled subgraph if return_splits: print(f"Saving splits to {processed_splits_path}") split_idx = dataset.get_idx_split() train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] torch.save({ 'train': train_idx, 'valid': valid_idx, 'test': test_idx }, processed_splits_path) # Create and save labeled subgraph (structure only, no features) print("Creating labeled subgraph for structure metrics (this may take a while)...") labeled_nodes = torch.cat([train_idx, valid_idx, test_idx]) print(f"Number of labeled nodes: {labeled_nodes.shape[0]:,}") # Extract subgraph edges for labeled nodes sub_edge_index = subgraph( labeled_nodes, edge_index_undirected, relabel_nodes=True, num_nodes=data.num_nodes )[0] # Create mapping from original indices to subgraph indices node_mapping = torch.full((data.num_nodes,), -1, dtype=torch.long) node_mapping[labeled_nodes] = torch.arange(labeled_nodes.shape[0]) # Create lightweight subgraph data object (no features, no labels) labeled_subgraph = { 'edge_index': sub_edge_index, 'num_nodes': labeled_nodes.shape[0], 'original_indices': labeled_nodes # Keep track of original node indices for mapping predictions } print(f"Labeled subgraph: {labeled_subgraph['num_nodes']:,} nodes, {sub_edge_index.shape[1]:,} edges") print(f"Saving labeled subgraph to {processed_subgraph_path}") torch.save(labeled_subgraph, processed_subgraph_path) print("Preprocessing complete!") if return_splits: return data.x, edge_index_undirected, y, train_idx, valid_idx, test_idx, labeled_subgraph else: return data.x, edge_index_undirected, y