Scaling to Massive Graphs

This tutorial demonstrates how PyAGC scales to graphs with millions of nodes and up to billions of edges through mini-batch training, neighbor sampling, and intelligent memory management.

The Scalability Challenge

Traditional graph clustering methods face two major bottlenecks when dealing with large graphs:

  1. Memory Bottleneck: Full-batch methods require loading the entire graph into GPU memory

  2. Computation Bottleneck: Message passing over millions of nodes becomes prohibitively slow

Memory Requirements Example

Dataset      Nodes   Edges   Features   Est. GPU Memory
Cora         2.7K    10.5K   1,433      ~50 MB
Reddit       233K    23M     602        ~8 GB
Products     2.4M    62M     100        ~32 GB (OOM!)
Papers100M   111M    1.6B    128        ~500 GB (Impossible!)
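As a rough sanity check on these figures, the raw storage of the inputs can be estimated from the node, edge, and feature counts alone. The sketch below assumes float32 features and an int64 edge_index of shape [2, E]; raw_input_bytes is a hypothetical helper, not part of PyAGC. It gives only a lower bound: the estimates in the table are larger, presumably because they also cover activations, gradients, and optimizer state.

```python
def raw_input_bytes(num_nodes, num_edges, num_features,
                    feature_bytes=4, index_bytes=8):
    """Lower bound on memory needed just to hold x and edge_index."""
    feature_matrix = num_nodes * num_features * feature_bytes
    edge_index = 2 * num_edges * index_bytes
    return feature_matrix + edge_index


# Rounded sizes taken from the table above.
datasets = {
    "Cora": (2_700, 10_500, 1_433),
    "Reddit": (233_000, 23_000_000, 602),
    "Products": (2_400_000, 62_000_000, 100),
    "Papers100M": (111_000_000, 1_600_000_000, 128),
}

for name, (nodes, edges, feats) in datasets.items():
    gb = raw_input_bytes(nodes, edges, feats) / 1e9
    print(f"{name:<11} {gb:8.3f} GB of raw inputs")
```

Under these assumptions Papers100M needs roughly 82 GB for the raw inputs alone, which already exceeds a 40 GB GPU before any model state is allocated.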

PyAGC’s Scalability Solutions

PyAGC provides three complementary strategies:

  1. Mini-Batch Training: Process the graph in small batches instead of all at once

  2. Neighbor Sampling: Limit the number of neighbors sampled during message passing

  3. Automatic Fallback: Intelligently switch between GPU/CPU and full-batch/mini-batch modes

Mini-Batch Training

Basic Configuration

Enable mini-batch training in your configuration file:

# train.conf.yaml
mini_batch: true              # Enable mini-batch mode
batch_size: 2048              # Nodes per batch
num_layers: 2                 # GNN depth
fan_out: 10                   # Neighbors to sample per layer
num_workers: 4                # Data loading workers

Using Mini-Batch in Code

All PyAGC models automatically support mini-batch training:

from pyagc.data import get_dataset
from pyagc.models import DGI
from pyagc.encoders import create_tuned_gnn
from torch_geometric.data import Data

# Load large graph
x, edge_index, y = get_dataset('Reddit', root='data')
data = Data(x=x, edge_index=edge_index)

# Create model
encoder = create_tuned_gnn(
    gnn_type='sage',  # GraphSAGE works well for mini-batch
    in_channels=data.num_features,
    hidden_channels=256,
    num_layers=2
)
model = DGI(hidden_channels=256, encoder=encoder)

# The model will automatically use mini-batch training
# based on your configuration

Under the hood, PyAGC uses PyTorch Geometric’s NeighborLoader to create mini-batches:

from torch_geometric.loader import NeighborLoader

# Automatically created during training
train_loader = NeighborLoader(
    data,
    input_nodes=None,                # Sample from all nodes
    num_neighbors=[fan_out] * num_layers,  # [10, 10] for 2 layers
    batch_size=batch_size,
    shuffle=True
)

Neighbor Sampling Strategies

Different sampling strategies for different scenarios:

# Uniform sampling (default)
fan_out: 10                   # Sample 10 neighbors per layer

# Layer-wise sampling
fan_out: [15, 10, 5]          # More neighbors in earlier layers

# Full neighborhood (for inference)
infer_fan_out: -1             # Sample all neighbors during inference
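Both fan_out forms end up as the per-layer list that NeighborLoader takes in num_neighbors. The helper below is a hypothetical sketch of that expansion (the name expand_fan_out and its validation rules are assumptions, not PyAGC's actual parsing code).

```python
def expand_fan_out(fan_out, num_layers):
    """Return one neighbor count per GNN layer.

    An int is repeated for every layer; a list must already have one
    entry per layer. -1 means "take all neighbors".
    """
    if isinstance(fan_out, int):
        sizes = [fan_out] * num_layers
    else:
        sizes = list(fan_out)
        if len(sizes) != num_layers:
            raise ValueError(
                f"fan_out has {len(sizes)} entries but the model has "
                f"{num_layers} layers"
            )
    for size in sizes:
        if size == 0 or size < -1:
            raise ValueError(f"invalid fan-out value: {size}")
    return sizes


print(expand_fan_out(10, 2))           # [10, 10]
print(expand_fan_out([15, 10, 5], 3))  # [15, 10, 5]
print(expand_fan_out(-1, 2))           # [-1, -1]
```

Rejecting a list whose length differs from num_layers catches a common configuration mistake early, before the loader is built.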

Inference Optimization

PyAGC provides separate configuration for inference to balance speed and accuracy:

Inference Batch Size

Use larger batches during inference since no gradients are computed:

batch_size: 1024              # Training batch size
infer_batch_size: 4096        # Inference batch size (can be larger)

The library automatically handles this:

# During training
model.train_batch(train_loader, optimizer, epoch)

# During inference - automatically uses larger batch size
z = model.infer_batch(inference_loader)

Full-Batch Inference Fallback

For better accuracy on medium-sized graphs, PyAGC can try full-batch inference first:

force_full_batch_inference: true   # Try full-batch first
allow_cpu_fallback: true          # Fall back to CPU if the GPU runs out of memory

The automatic fallback strategy:

  1. Try GPU full-batch - Best accuracy, fastest if it fits

  2. Fall back to CPU full-batch - Same full-batch accuracy, slower but feasible

  3. Fall back to GPU mini-batch - Good accuracy, memory-efficient

# Handled automatically by inference_embeddings()
z, inference_time = inference_embeddings(
    model, data, conf, logger, device,
    labeled_indices=labeled_indices  # Optional: only infer subset
)
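The fallback order can be sketched as a loop over candidate strategies that moves on whenever one runs out of memory. This is an illustrative, framework-free sketch, not the implementation of inference_embeddings; run_with_fallback and the stand-in strategy functions are hypothetical, and OOM is detected through the "out of memory" text that PyTorch puts in its CUDA errors.

```python
def is_oom(error):
    """True for host MemoryError or a RuntimeError that mentions OOM."""
    if isinstance(error, MemoryError):
        return True
    return isinstance(error, RuntimeError) and "out of memory" in str(error).lower()


def run_with_fallback(strategies):
    """Try (name, callable) pairs in order; skip the ones that hit OOM."""
    for name, strategy in strategies:
        try:
            return name, strategy()
        except (MemoryError, RuntimeError) as error:
            if not is_oom(error):
                raise  # unrelated failure: surface it instead of hiding it
            print(f"{name} ran out of memory, trying next strategy")
    raise RuntimeError("all inference strategies ran out of memory")


# Stand-ins for the three stages; real ones would run the model.
def gpu_full_batch():
    raise RuntimeError("CUDA out of memory. Tried to allocate 2.34 GiB")


def cpu_full_batch():
    raise MemoryError()


def gpu_mini_batch():
    return "embeddings"


used, result = run_with_fallback([
    ("GPU full-batch", gpu_full_batch),
    ("CPU full-batch", cpu_full_batch),
    ("GPU mini-batch", gpu_mini_batch),
])
print(used, result)
```

Re-raising errors that are not memory related keeps genuine bugs, such as shape mismatches, from being masked by the fallback.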

Memory Management

Handling OOM Errors

PyAGC automatically handles out-of-memory errors:

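import torch

# Automatic batch size reduction on OOM
max_retries = 3
retry_count = 0

while retry_count < max_retries:
    try:
        # Try inference with current batch size
        z = model.infer_batch(inference_loader)
        break
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise  # Not an OOM error: do not retry
        # Halve the batch size (floor of 32) and retry
        infer_batch_size = max(infer_batch_size // 2, 32)
        logger.warning(f"OOM detected, reducing batch size to {infer_batch_size}")
        torch.cuda.empty_cache()
        retry_count += 1
        # Rebuild inference_loader with the new infer_batch_size before retrying
else:
    raise RuntimeError(f"Inference still out of memory after {max_retries} retries")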

Checkpoint Management

Save checkpoints during training to recover from interruptions:

save_every: 10                # Save every 10 epochs
save_every_batch: 1000        # Save every 1000 batches (for very large graphs)

Resume from the latest checkpoint in your training script:

from pyagc.utils import CheckpointManager

ckpt_manager = CheckpointManager('ckpts/reddit', 'seed0', logger)

# Resume training from checkpoint
if args.resume:
    checkpoint = ckpt_manager.load_checkpoint(
        model, optimizer, load_best=False, device=device
    )
    start_epoch = checkpoint['epoch'] + 1

Specialized Support for Large Graphs

Papers100M Dataset

PyAGC has special optimizations for billion-scale graphs:

# Automatically detects Papers100M and loads efficiently
x, edge_index, y, train_idx, valid_idx, test_idx, labeled_subgraph = get_dataset(
    'Papers100M',
    root='data',
    return_splits=True  # Returns labeled subset
)

# Only compute embeddings for labeled nodes (1.8M instead of 111M)
labeled_indices = torch.cat([train_idx, valid_idx, test_idx])

z, inference_time = inference_embeddings(
    model, data, conf, logger, device,
    labeled_indices=labeled_indices  # Huge memory savings!
)

# Use labeled subgraph for structure metrics
struct_results = structure_metrics(
    labeled_subgraph['edge_index'],  # Much smaller graph
    pred,
    metrics=['Mod', 'Cond']
)
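For reference, here is a small pure-Python sketch of the two structure metrics, assuming 'Mod' and 'Cond' stand for Newman modularity and conductance on an undirected graph. The modularity and conductance functions below are illustrative helpers, not PyAGC's structure_metrics, which may differ in edge handling and in how per-cluster values are averaged.

```python
from collections import defaultdict


def modularity(edges, labels):
    """Newman modularity: sum over clusters of L_c/m - (d_c / 2m)^2."""
    m = len(edges)
    internal = defaultdict(int)   # L_c: edges with both ends in cluster c
    degree = defaultdict(int)     # d_c: total degree of cluster c
    for u, v in edges:
        degree[labels[u]] += 1
        degree[labels[v]] += 1
        if labels[u] == labels[v]:
            internal[labels[u]] += 1
    return sum(internal[c] / m - (degree[c] / (2 * m)) ** 2 for c in degree)


def conductance(edges, labels, cluster):
    """Cut edges of `cluster` divided by the smaller side's total degree."""
    cut = inside = outside = 0
    for u, v in edges:
        in_u, in_v = labels[u] == cluster, labels[v] == cluster
        cut += in_u != in_v
        inside += in_u + in_v
        outside += (not in_u) + (not in_v)
    return cut / min(inside, outside)


# Two triangles joined by a single edge (2-3); each edge is listed once.
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
labels = [0, 0, 0, 1, 1, 1]
print(round(modularity(edges, labels), 4))      # 0.3571
print(round(conductance(edges, labels, 0), 4))  # 0.1429
```

Both metrics need only the edge list and the predicted labels, which is why they can be computed on the much smaller labeled subgraph.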

Model-Specific Optimizations

MAGI: Custom data loader for random walk sampling

from pyagc.models.magi import MAGINeighborLoader

train_loader = MAGINeighborLoader(
    data,
    num_neighbors=[10] * 2,
    num_walks=20,           # Random walks per node
    walk_length=4,          # Steps per walk
    batch_size=2048
)
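To make num_walks and walk_length concrete, the following sketch generates uniform random walks over a plain adjacency list. It illustrates random walk sampling in general and is not the MAGINeighborLoader implementation; random_walks and the toy graph are hypothetical.

```python
import random


def random_walks(adjacency, start, num_walks, walk_length, rng):
    """Return `num_walks` uniform random walks of `walk_length` steps each."""
    walks = []
    for _ in range(num_walks):
        walk = [start]
        for _ in range(walk_length):
            neighbors = adjacency[walk[-1]]
            if not neighbors:       # dead end: stop this walk early
                break
            walk.append(rng.choice(neighbors))
        walks.append(walk)
    return walks


# Toy undirected graph: a 4-cycle 0-1-2-3 with a chord 0-2.
adjacency = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}
rng = random.Random(0)  # seeded for reproducibility
walks = random_walks(adjacency, start=0, num_walks=20, walk_length=4, rng=rng)
print(len(walks), "walks, first:", walks[0])
```

Each walk holds the start node plus walk_length visited nodes, so the cost per seed node grows with num_walks * walk_length.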

DAEGC: Multi-stage training with separate configurations

pretrain_epochs: 200
pretrain_batch_size: 2048

finetune_epochs: 100
finetune_batch_size: 4096  # Can be larger in finetuning

Practical Example: Scaling to Reddit

Complete workflow for the Reddit dataset (233K nodes, 23M edges):

Configuration File

# train.conf.yaml for Reddit
dataset: Reddit

# Model
gnn_type: sage
hidden_channels: 256
num_layers: 2
dropout: 0.0

# Training
mini_batch: true
batch_size: 2048
fan_out: 10
epochs: 200
lr: 0.001

# Inference
infer_batch_size: 8192
infer_fan_out: -1
force_full_batch_inference: false
allow_cpu_fallback: true

# Checkpointing
save_every: 20

Training Script

import torch
from pyagc.data import get_dataset
from pyagc.models import DGI
from pyagc.encoders import create_tuned_gnn
from pyagc.utils import get_training_config, CheckpointManager
from torch_geometric.data import Data

# Load configuration
conf = get_training_config('Reddit', config_path='train.conf.yaml')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load dataset
x, edge_index, y = get_dataset('Reddit', root='data')
data = Data(x=x, edge_index=edge_index)
print(f"Graph: {data.num_nodes:,} nodes, {data.num_edges:,} edges")

# Create model
encoder = create_tuned_gnn(
    gnn_type='sage',
    in_channels=data.num_features,
    hidden_channels=256,
    num_layers=2
)
model = DGI(hidden_channels=256, encoder=encoder).to(device)

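# Setup logging and checkpointing
import logging
logger = logging.getLogger(__name__)
ckpt_manager = CheckpointManager('ckpts/reddit', 'seed0', logger)

# Build the mini-batch loader from the configuration values
from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader(
    data,
    input_nodes=None,            # Sample from all nodes
    num_neighbors=[10] * 2,      # fan_out=10 for each of the 2 layers
    batch_size=2048,
    shuffle=True
)

# Training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1, 201):
    loss = model.train_batch(train_loader, optimizer, epoch)

    if epoch % 20 == 0:
        ckpt_manager.save_checkpoint(model, optimizer, epoch, loss)

# Inference with automatic fallback (same helper as in the sections above)
z, inference_time = inference_embeddings(model, data, conf, logger, device)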

# Clustering
from pyagc.clusters import KMeansClusterHead
kmeans = KMeansClusterHead(n_clusters=41)  # Reddit has 41 classes
pred = kmeans.fit_predict(z)

Expected Output

============================================================
System Information
============================================================
Dataset: Data(x=[232965, 602], edge_index=[2, 114615892])
Nodes: 232,965
Edges: 114,615,892
Training mode: Mini-batch
GNN type: SAGE
Total parameters: 1.234M
Device: cuda:0
GPU: NVIDIA A100-SXM4-40GB

============================================================
Training Mode: Mini-batch (NeighborLoader)
============================================================
Batch size: 2048
Fan-out: [10] × 2 layers

Epoch 001: 100%|██████████| 232965/232965 [00:42<00:00]
Epoch: 001 Loss: 0.8234
...
Training completed in 142.34s (avg 711ms/epoch)

============================================================
Inference Stage
============================================================
Attempting full-batch inference on cuda:0...
Full-batch inference failed on cuda:0 due to OOM
Using mini-batch inference on cuda:0...
Inference batch size: 8192
✓ Mini-batch inference completed in 8.21s

============================================================
Clustering Stage
============================================================
Run 1/5: NMI=45.23, ARI=42.18, ACC=68.91 [3.12s]
...
Final: NMI=45.67±0.31, ARI=42.45±0.28, ACC=69.12±0.34

Performance Comparison

Performance on different graph sizes:

Dataset    Nodes   Mode         Training Time   Peak Memory   NMI
Cora       2.7K    Full-batch   12s             0.2 GB        68.5
Cora       2.7K    Mini-batch   18s             0.1 GB        68.2
Reddit     233K    Full-batch   OOM             -             -
Reddit     233K    Mini-batch   142s            4.2 GB        45.7
Products   2.4M    Full-batch   OOM             -             -
Products   2.4M    Mini-batch   8.3 min         12 GB         52.3

Best Practices

Choosing Batch Size

# Rule of thumb for batch_size:
# - Small graphs (< 10K nodes): Use full-batch
# - Medium graphs (10K - 1M nodes): batch_size = 1024-4096
# - Large graphs (> 1M nodes): batch_size = 4096-8192

if data.num_nodes < 10000:
    conf['mini_batch'] = False
elif data.num_nodes < 1000000:
    conf['mini_batch'] = True
    conf['batch_size'] = 2048
else:
    conf['mini_batch'] = True
    conf['batch_size'] = 8192

Choosing Fan-out

# Trade-off between accuracy and speed:
# - High fan-out (20-50): Better accuracy, slower
# - Medium fan-out (10-15): Balanced
# - Low fan-out (5-10): Faster, may lose accuracy

# Deeper GNNs need smaller fan-out to control batch size
if num_layers == 2:
    fan_out = 15
elif num_layers == 3:
    fan_out = 10
else:
    fan_out = 5
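Deeper models need a smaller fan-out because the sampled neighborhood grows multiplicatively with depth. A quick upper bound can be computed as follows; it ignores overlap between neighborhoods, so real batches are usually smaller, and max_sampled_nodes is a hypothetical helper.

```python
def max_sampled_nodes(fan_outs, batch_size=1):
    """Upper bound on nodes in a sampled batch: seeds plus each hop's frontier."""
    frontier = batch_size
    total = batch_size
    for fan_out in fan_outs:
        frontier *= fan_out
        total += frontier
    return total


for num_layers, fan_out in [(2, 15), (3, 10), (4, 5)]:
    per_seed = max_sampled_nodes([fan_out] * num_layers)
    print(f"{num_layers} layers, fan-out {fan_out}: up to {per_seed} nodes per seed")

# A full training batch with batch_size=2048 and fan_out=10 over 2 layers
print(max_sampled_nodes([10, 10], batch_size=2048))  # 227328
```

Keeping fan-out 15 with 3 layers would raise the bound to 1 + 15 + 225 + 3375 = 3616 nodes per seed, which is why the rule of thumb lowers it to 10.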

GNN Architecture Selection

# Best GNN types for mini-batch training:
# 1. GraphSAGE (sampling-based, designed for mini-batch)
# 2. GAT (attention weights computed locally)

# Avoid for mini-batch:
# - GCN with cached=True (expects full graph)
# - Models requiring global information

Troubleshooting

Common Issues and Solutions

Issue 1: Out of Memory During Training

RuntimeError: CUDA out of memory. Tried to allocate 2.34 GiB

Solution: Reduce batch size or fan-out

# Original configuration
batch_size: 4096
fan_out: 15

# Reduced configuration
batch_size: 1024    # Reduce by 4x
fan_out: 10         # Reduce neighbors

Issue 2: Out of Memory During Inference

RuntimeError: CUDA out of memory during inference

Solution: Enable automatic fallback and reduce inference batch size

force_full_batch_inference: false
allow_cpu_fallback: true
infer_batch_size: 2048  # Reduce until inference fits in GPU memory
infer_fan_out: -1       # Keep full neighborhood for accuracy

Issue 3: Training is Too Slow

# Epoch takes 10+ minutes on medium-sized graph

Solution: Increase batch size and use multiple workers

batch_size: 8192        # Larger batches = fewer iterations
num_workers: 4          # Parallel data loading
fan_out: 10             # Reduce if still slow
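The effect of batch size on iteration count is simple arithmetic: one epoch visits every seed node once, so the number of batches is the node count divided by the batch size, rounded up. A small sketch using Reddit's 232,965 nodes; batches_per_epoch is a hypothetical helper.

```python
import math


def batches_per_epoch(num_nodes, batch_size):
    """Number of mini-batches needed to cover every seed node once."""
    return math.ceil(num_nodes / batch_size)


reddit_nodes = 232_965
for batch_size in (1024, 2048, 8192):
    print(batch_size, batches_per_epoch(reddit_nodes, batch_size))
```

Larger batches cut the number of iterations, but each batch also samples a larger subgraph, so memory use grows with batch size.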

Issue 4: Poor Clustering Performance in Mini-Batch Mode

# NMI drops from 68.5 (full-batch) to 45.2 (mini-batch)

Solution: Use full neighborhood during inference

# Training can use sampling
fan_out: 10

# Inference uses full neighborhood
infer_fan_out: -1
force_full_batch_inference: true  # Try full-batch inference

Issue 5: DataLoader Deadlock with Multiple Workers

# Training hangs at first epoch with num_workers > 0

Solution: Set num_workers=0 or use proper multiprocessing setup

num_workers: 0  # Safe default, especially on CPU

Or use proper initialization:

if __name__ == '__main__':
    import torch.multiprocessing as mp
    mp.set_start_method('spawn', force=True)
    main()

Advanced Techniques

Gradient Accumulation for Larger Effective Batch Size

When GPU memory limits batch size but you want larger effective batches:

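# Accumulate gradients over multiple mini-batches
accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

for epoch in range(epochs):
    optimizer.zero_grad()
    total_loss = 0.0
    num_batches = len(train_loader)

    for i, batch in enumerate(train_loader):
        batch = batch.to(device)
        loss_output = model.loss_batch(batch)
        # Scale so the accumulated gradient is an average over the group
        loss = loss_output.total / accumulation_steps
        loss.backward()

        # Step once per group, and once more for a trailing partial group
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        # Track the unscaled loss for logging
        total_loss += loss_output.total.item()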
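Dividing each loss by accumulation_steps is what makes the accumulated gradient match the gradient of one large batch. The framework-free sketch below checks this for a one-parameter least-squares model with equally sized mini-batches; the data and the grad_mse helper are made up for illustration.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5

# One large batch of 8 samples
full_grad = grad_mse(w, xs, ys)

# Four mini-batches of 2 samples, each gradient scaled by 1/accumulation_steps
accumulation_steps = 4
batch_size = len(xs) // accumulation_steps
accumulated = 0.0
for step in range(accumulation_steps):
    lo, hi = step * batch_size, (step + 1) * batch_size
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps

print(full_grad, accumulated)
```

The match is exact here because the loss is a mean over independent samples. For objectives that depend on the composition of the whole batch, accumulation only approximates a larger batch.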

Mixed Precision Training

Use automatic mixed precision (AMP) to reduce memory usage:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for epoch in range(epochs):
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Forward pass with autocast
        with autocast():
            loss_output = model.loss_batch(batch)
            loss = loss_output.total

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

This can reduce memory usage by ~30-40% with minimal accuracy loss.
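The gradient scaling step exists because float16 cannot represent very small values. The sketch below uses Python's struct module (format 'e' is IEEE half precision) to show a small gradient underflowing to zero and surviving once it is scaled. The scale of 65536 is an assumption chosen for illustration, and the example is independent of PyTorch.

```python
import struct


def to_float16(value):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]


tiny_grad = 1e-8
scale = 65536.0  # illustrative loss scale

unscaled = to_float16(tiny_grad)          # underflows to 0.0
scaled = to_float16(tiny_grad * scale)    # representable in float16
recovered = scaled / scale                # unscale in higher precision

print(unscaled, scaled, recovered)
```

Multiplying the loss by the scale shifts gradients into float16's representable range, and dividing by the same factor before the optimizer step restores their magnitude.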

Distributed Training (Multi-GPU)

For extremely large graphs, distribute training across multiple GPUs:

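import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch_geometric.loader import NeighborLoader

# Initialize distributed training (torchrun sets LOCAL_RANK for each process)
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(local_rank)
device = torch.device(f'cuda:{local_rank}')

# Wrap model with DDP
model = DGI(hidden_channels=256, encoder=encoder).to(device)
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Distributed data loading: each rank samples from its own shard of seed nodes
seed_nodes = torch.arange(data.num_nodes).tensor_split(world_size)[rank]
train_loader = NeighborLoader(
    data,
    input_nodes=seed_nodes,
    num_neighbors=[10] * 2,
    batch_size=2048 // world_size,  # Divide batch across GPUs
    shuffle=True
)

# Training loop remains the same.
# Note: DDP synchronizes gradients for forward passes made through the DDP
# wrapper, so confirm that train_batch is compatible with it in your setup.
for epoch in range(epochs):
    avg_loss = model.module.train_batch(train_loader, optimizer, epoch)

dist.destroy_process_group()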

Run with:

torchrun --nproc_per_node=4 train.py --dataset Reddit

Dynamic Sampling Strategies

Adjust sampling strategy during training:

def get_adaptive_fan_out(epoch, max_epochs):
    """Increase fan-out as training progresses for better accuracy."""
    min_fan_out = 5
    max_fan_out = 15
    progress = epoch / max_epochs
    return int(min_fan_out + (max_fan_out - min_fan_out) * progress)

for epoch in range(1, epochs + 1):
    fan_out = get_adaptive_fan_out(epoch, epochs)

    train_loader = NeighborLoader(
        data,
        num_neighbors=[fan_out] * num_layers,
        batch_size=batch_size,
        shuffle=True
    )

    avg_loss = model.train_batch(train_loader, optimizer, epoch)

Subgraph-Based Evaluation

For billion-scale graphs like Papers100M, evaluate on manageable subgraphs:

# Extract labeled subgraph
labeled_indices = torch.cat([train_idx, valid_idx, test_idx])

# Get subgraph containing only labeled nodes and their edges
from torch_geometric.utils import subgraph

edge_index_sub, _ = subgraph(
    labeled_indices,
    edge_index,
    relabel_nodes=True,
    num_nodes=data.num_nodes
)

# Only compute embeddings for labeled nodes
z_labeled, _ = inference_embeddings(
    model, data, conf, logger, device,
    labeled_indices=labeled_indices
)

# Evaluate on subgraph
struct_results = structure_metrics(
    edge_index_sub,
    pred,
    metrics=['Mod', 'Cond']
)
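The effect of extracting a relabeled subgraph can be illustrated without PyTorch: keep only edges whose endpoints are both in the chosen node set, then renumber the kept nodes 0..k-1 (here by their position in the subset) so that edge ids line up with the rows of the subset's embeddings. induced_subgraph is a hypothetical pure-Python stand-in, not the PyG function.

```python
def induced_subgraph(subset, edges):
    """Keep edges inside `subset` and relabel nodes by position in `subset`."""
    new_id = {node: i for i, node in enumerate(subset)}
    return [
        (new_id[u], new_id[v])
        for u, v in edges
        if u in new_id and v in new_id
    ]


edges = [(0, 5), (5, 9), (9, 0), (2, 5), (7, 9)]
labeled = [9, 0, 5]  # e.g. concatenated train, valid, and test indices
print(induced_subgraph(labeled, edges))  # [(1, 2), (2, 0), (0, 1)]
```

Edges touching unlabeled nodes (2 and 7 in this toy graph) are dropped, which is where the reduction in edge count comes from.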

Performance Profiling

Identify bottlenecks in your pipeline:

import torch.profiler as profiler

with profiler.profile(
    activities=[
        profiler.ProfilerActivity.CPU,
        profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    profile_memory=True,
) as prof:
    for epoch in range(1, 6):  # Profile first 5 epochs
        avg_loss = model.train_batch(train_loader, optimizer, epoch)

# Print profiling results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=10
))

# Export for visualization
prof.export_chrome_trace("trace.json")
# View at chrome://tracing

Example output:

---------------------------------  ------------  ------------  ------------
Name                                Self CPU %    Self CUDA %    CUDA total
---------------------------------  ------------  ------------  ------------
aten::matmul                            15.2%         42.3%       1.234s
aten::index_select                       8.4%         18.7%       543ms
Neighbor Sampling                       12.1%          0.0%       352ms
---------------------------------  ------------  ------------  ------------

Real-World Case Studies

Case Study 1: Reddit (233K nodes)

Challenge: 23M edges cause OOM with full-batch training.

Solution:

gnn_type: sage
mini_batch: true
batch_size: 2048
fan_out: 10
num_layers: 2
infer_fan_out: -1  # Full neighborhood for inference

Results:

  • Training time: 142s (vs OOM with full-batch)

  • Peak memory: 4.2 GB (vs 32+ GB required for full-batch)

  • NMI: 45.7 (vs 46.2 with full-batch, a 0.5-point drop, about 1% relative)

Case Study 2: ogbn-products (2.4M nodes)

Challenge: Full-batch training on 2.4M nodes and 62M edges runs out of GPU memory.

Solution:

gnn_type: sage
mini_batch: true
batch_size: 8192
fan_out: 10
num_layers: 2

# Use larger inference batch
infer_batch_size: 16384
force_full_batch_inference: false
allow_cpu_fallback: true

Results:

  • Training time: 8.3 min

  • Peak memory: 12 GB

  • NMI: 52.3

  • Inference: GPU mini-batch (CPU fallback not needed)

Case Study 3: ogbn-papers100M (111M nodes)

Challenge: A graph with 111M nodes and 1.6B edges cannot be stored in the memory of a single GPU.

Solution:

gnn_type: sage
mini_batch: true
batch_size: 8192
fan_out: 5  # Lower fan-out for extreme scale
num_layers: 2

# Inference optimizations
infer_batch_size: 32768
infer_fan_out: 10  # Limited neighborhood for feasibility

Code:

# Load with labeled subset
x, edge_index, y, train_idx, valid_idx, test_idx, labeled_subgraph = \
    get_dataset('Papers100M', root='data', return_splits=True)

labeled_indices = torch.cat([train_idx, valid_idx, test_idx])
# Only 1.8M nodes (1.6% of total) need embeddings

# Train on full graph, infer on labeled subset
z, _ = inference_embeddings(
    model, data, conf, logger, device,
    labeled_indices=labeled_indices  # Huge memory savings
)

Results:

  • Training time: ~2 hours

  • Peak memory: 24 GB (single A100 GPU)

  • Evaluated on 1.8M labeled nodes instead of 111M

  • Structure metrics computed on labeled subgraph (23M edges vs 1.6B)

Summary and Recommendations

Quick Decision Guide

When to use full-batch:

  • Graph has < 10K nodes

  • GPU has sufficient memory (check with dataset_stats.py)

  • Need maximum accuracy

When to use mini-batch:

  • Graph has > 10K nodes

  • Getting OOM errors

  • Training on CPU

  • Want to scale to millions of nodes

Recommended configurations:

# Small graphs (< 10K nodes)
config = {
    'mini_batch': False,
    'cached': True
}

# Medium graphs (10K - 1M nodes)
config = {
    'mini_batch': True,
    'batch_size': 2048,
    'fan_out': 10,
    'infer_batch_size': 8192,
    'infer_fan_out': -1,
    'force_full_batch_inference': True
}

# Large graphs (1M - 10M nodes)
config = {
    'mini_batch': True,
    'batch_size': 4096,
    'fan_out': 10,
    'infer_batch_size': 16384,
    'infer_fan_out': 15,
    'force_full_batch_inference': False
}

# Extreme scale (> 10M nodes)
config = {
    'mini_batch': True,
    'batch_size': 8192,
    'fan_out': 5,
    'infer_batch_size': 32768,
    'infer_fan_out': 10,
    'allow_cpu_fallback': True
}
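The four presets can be folded into one lookup keyed on node count. This is a hypothetical convenience function (recommend_config is not part of PyAGC); the thresholds and values are copied from the presets above.

```python
def recommend_config(num_nodes):
    """Return the preset above that matches the graph size."""
    if num_nodes < 10_000:            # small graphs
        return {"mini_batch": False, "cached": True}
    if num_nodes < 1_000_000:         # medium graphs
        return {
            "mini_batch": True, "batch_size": 2048, "fan_out": 10,
            "infer_batch_size": 8192, "infer_fan_out": -1,
            "force_full_batch_inference": True,
        }
    if num_nodes <= 10_000_000:       # large graphs
        return {
            "mini_batch": True, "batch_size": 4096, "fan_out": 10,
            "infer_batch_size": 16384, "infer_fan_out": 15,
            "force_full_batch_inference": False,
        }
    return {                          # extreme scale
        "mini_batch": True, "batch_size": 8192, "fan_out": 5,
        "infer_batch_size": 32768, "infer_fan_out": 10,
        "allow_cpu_fallback": True,
    }


for name, nodes in [("Cora", 2_700), ("Reddit", 233_000),
                    ("Products", 2_400_000), ("Papers100M", 111_000_000)]:
    print(name, recommend_config(nodes))
```

Treat the result as a starting point and adjust batch_size and fan_out to the memory actually available.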

Key Takeaways

  1. PyAGC scales from thousands of nodes to graphs with over 100M nodes and billions of edges using neighbor sampling

  2. Automatic fallback handles OOM gracefully without code changes

  3. Separate inference config optimizes for speed and accuracy

  4. Smart checkpointing enables resuming long training runs

Next Steps