Scaling to Massive Graphs
This tutorial demonstrates how PyAGC scales to graphs with millions of nodes and billions of edges through mini-batch training, neighbor sampling, and automatic fallback between GPU/CPU and full-batch/mini-batch modes.
The Scalability Challenge
Traditional graph clustering methods face two major bottlenecks when dealing with large graphs:
Memory Bottleneck: Full-batch methods require loading the entire graph into GPU memory
Computation Bottleneck: Message passing over millions of nodes becomes prohibitively slow
Memory Requirements Example
| Dataset | Nodes | Edges | Features | Est. GPU Memory |
|---|---|---|---|---|
| Cora | 2.7K | 10.5K | 1,433 | ~50 MB |
| Reddit | 233K | 23M | 602 | ~8 GB |
| Products | 2.4M | 62M | 100 | ~32 GB (OOM!) |
| Papers100M | 111M | 1.6B | 128 | ~500 GB (Impossible!) |
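As a rough sanity check on the table above, the raw storage for node features and edge indices can be computed directly. The sketch below assumes float32 features and an int64 edge index of shape [2, num_edges]; `estimate_graph_bytes` is a hypothetical helper, not a PyAGC function. It gives a lower bound only: activations, gradients, and optimizer state come on top, which is why the estimates in the table are larger.

```python
def estimate_graph_bytes(num_nodes, num_edges, num_features,
                         feature_bytes=4, index_bytes=8):
    """Lower-bound storage for features plus a [2, num_edges] edge index."""
    feature_storage = num_nodes * num_features * feature_bytes
    edge_storage = 2 * num_edges * index_bytes
    return feature_storage + edge_storage


# Approximate sizes taken from the table above.
datasets = {
    "Cora": (2_700, 10_500, 1_433),
    "Reddit": (233_000, 23_000_000, 602),
    "Products": (2_400_000, 62_000_000, 100),
    "Papers100M": (111_000_000, 1_600_000_000, 128),
}

for name, (nodes, edges, feats) in datasets.items():
    gb = estimate_graph_bytes(nodes, edges, feats) / 1e9
    print(f"{name:<12} >= {gb:8.2f} GB for raw features and edges")
```

Even this lower bound puts Papers100M above 80 GB, more than a single 40 GB GPU can hold before any computation starts.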
PyAGC’s Scalability Solutions
PyAGC provides three complementary strategies:
Mini-Batch Training: Process the graph in small batches instead of all at once
Neighbor Sampling: Limit the number of neighbors sampled during message passing
Automatic Fallback: Intelligently switch between GPU/CPU and full-batch/mini-batch modes
Mini-Batch Training
Basic Configuration
Enable mini-batch training in your configuration file:
# train.conf.yaml
mini_batch: true # Enable mini-batch mode
batch_size: 2048 # Nodes per batch
num_layers: 2 # GNN depth
fan_out: 10 # Neighbors to sample per layer
num_workers: 4 # Data loading workers
Using Mini-Batch in Code
All PyAGC models automatically support mini-batch training:
from pyagc.data import get_dataset
from pyagc.models import DGI
from pyagc.encoders import create_tuned_gnn
from torch_geometric.data import Data
# Load large graph
x, edge_index, y = get_dataset('Reddit', root='data')
data = Data(x=x, edge_index=edge_index)
# Create model
encoder = create_tuned_gnn(
    gnn_type='sage',  # GraphSAGE works well for mini-batch
    in_channels=data.num_features,
    hidden_channels=256,
    num_layers=2
)
model = DGI(hidden_channels=256, encoder=encoder)
# The model will automatically use mini-batch training
# based on your configuration
Under the hood, PyAGC uses PyTorch Geometric’s NeighborLoader to create mini-batches:
from torch_geometric.loader import NeighborLoader
# Automatically created during training
train_loader = NeighborLoader(
    data,
    input_nodes=None,  # Sample from all nodes
    num_neighbors=[fan_out] * num_layers,  # [10, 10] for 2 layers
    batch_size=batch_size,
    shuffle=True
)
Neighbor Sampling Strategies
Different sampling strategies for different scenarios:
# Uniform sampling (default)
fan_out: 10 # Sample 10 neighbors per layer
# Layer-wise sampling
fan_out: [15, 10, 5] # More neighbors in earlier layers
# Full neighborhood (for inference)
infer_fan_out: -1 # Sample all neighbors during inference
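To see why fan-out and depth matter so much, it helps to bound the size of one sampled batch. With `batch_size` seed nodes and at most `fan_out[i]` neighbors drawn per node at hop `i`, the sampled subgraph can contain up to the number of nodes computed below. `max_sampled_nodes` is an illustrative helper (not part of PyAGC or PyTorch Geometric), it does not handle the `-1` "all neighbors" setting, and the bound is loose because real samples overlap and low-degree nodes have fewer neighbors.

```python
def max_sampled_nodes(batch_size, fan_out):
    """Upper bound on nodes in one sampled batch (seeds plus every hop)."""
    total = frontier = batch_size
    for k in fan_out:
        frontier *= k      # each frontier node adds at most k neighbors
        total += frontier
    return total


print(max_sampled_nodes(2048, [10, 10]))     # 227328
print(max_sampled_nodes(2048, [15, 10, 5]))  # 1875968
```

Adding a third layer raises the worst case by roughly an order of magnitude, which is the reason deeper models are usually paired with smaller fan-outs.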
Inference Optimization
PyAGC provides separate configuration for inference to balance speed and accuracy:
Inference Batch Size
Use larger batches during inference since no gradients are computed:
batch_size: 1024 # Training batch size
infer_batch_size: 4096 # Inference batch size (can be larger)
The library automatically handles this:
# During training
model.train_batch(train_loader, optimizer, epoch)
# During inference - automatically uses larger batch size
z = model.infer_batch(inference_loader)
Full-Batch Inference Fallback
For better accuracy on medium-sized graphs, PyAGC can try full-batch inference first:
force_full_batch_inference: true # Try full-batch first
allow_cpu_fallback: true # Fallback to CPU if GPU OOMs
The automatic fallback strategy:
Try GPU full-batch - Best accuracy, fastest if the graph fits
Fall back to CPU full-batch - Better accuracy than mini-batch, slower but feasible
Fall back to GPU mini-batch - Good accuracy, memory-efficient
# Handled automatically by inference_embeddings()
z, inference_time = inference_embeddings(
    model, data, conf, logger, device,
    labeled_indices=labeled_indices  # Optional: only infer subset
)
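The fallback order described above can be pictured as a chain of strategies tried in turn. The following is a minimal sketch of that control flow only, not PyAGC's implementation: the three strategy functions are stand-ins, and `MemoryError` is used in place of the `RuntimeError` that a CUDA out-of-memory condition raises in PyTorch.

```python
def run_with_fallback(strategies):
    """Try (name, callable) pairs in order; return the first that succeeds."""
    failures = []
    for name, run in strategies:
        try:
            return name, run()
        except MemoryError as err:  # stand-in for an out-of-memory failure
            failures.append(f"{name}: {err}")
    raise MemoryError("all strategies failed: " + "; ".join(failures))


def gpu_full_batch():
    raise MemoryError("graph does not fit on the GPU")  # simulated OOM


def cpu_full_batch():
    return "embeddings from CPU full-batch"


def gpu_mini_batch():
    return "embeddings from GPU mini-batch"


used, result = run_with_fallback([
    ("gpu-full-batch", gpu_full_batch),
    ("cpu-full-batch", cpu_full_batch),
    ("gpu-mini-batch", gpu_mini_batch),
])
print(used)  # cpu-full-batch
```

Ordering the list from most accurate to most memory-efficient means the first strategy that fits is also the best one available on the current hardware.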
Memory Management
Handling OOM Errors
PyAGC automatically handles out-of-memory errors:
# Automatic batch size reduction on OOM
max_retries = 3
retry_count = 0
while retry_count < max_retries:
    try:
        # Try inference with current batch size
        z = model.infer_batch(inference_loader)
        break
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise  # Not an OOM error: do not swallow it
        # Reduce batch size and retry
        infer_batch_size = max(infer_batch_size // 2, 32)
        logger.warning(f"OOM detected, reducing batch size to {infer_batch_size}")
        torch.cuda.empty_cache()
        # Rebuild inference_loader with the new infer_batch_size before retrying
        retry_count += 1
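The retry pattern can be isolated into a small, testable function. This is a sketch under stated assumptions rather than PyAGC code: `infer_with_halving` and `fake_inference` are hypothetical names, and `MemoryError` again stands in for the CUDA OOM `RuntimeError`. In a real pipeline, `run_inference` would rebuild the data loader with the new batch size on each call.

```python
def infer_with_halving(run_inference, batch_size, min_batch_size=32, max_retries=3):
    """Call run_inference(batch_size), halving the batch size after each OOM."""
    for attempt in range(max_retries + 1):
        try:
            return run_inference(batch_size), batch_size
        except MemoryError:  # stand-in for a CUDA OOM RuntimeError
            if attempt == max_retries or batch_size == min_batch_size:
                raise  # out of retries, or the batch cannot shrink further
            batch_size = max(batch_size // 2, min_batch_size)


def fake_inference(batch_size):
    """Pretend that only batches of at most 2048 nodes fit in memory."""
    if batch_size > 2048:
        raise MemoryError(f"batch of {batch_size} nodes does not fit")
    return f"embeddings computed with batch_size={batch_size}"


result, used = infer_with_halving(fake_inference, batch_size=8192)
print(used)  # 2048
```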
Checkpoint Management
Save checkpoints during training to recover from interruptions:
save_every: 10 # Save every 10 epochs
save_every_batch: 1000 # Save every 1000 batches (for very large graphs)
from pyagc.utils import CheckpointManager
ckpt_manager = CheckpointManager('ckpts/reddit', 'seed0', logger)
# Resume training from checkpoint
if args.resume:
    checkpoint = ckpt_manager.load_checkpoint(
        model, optimizer, load_best=False, device=device
    )
    start_epoch = checkpoint['epoch'] + 1
Specialized Support for Large Graphs
Papers100M Dataset
PyAGC has special optimizations for billion-scale graphs:
# Automatically detects Papers100M and loads efficiently
x, edge_index, y, train_idx, valid_idx, test_idx, labeled_subgraph = get_dataset(
    'Papers100M',
    root='data',
    return_splits=True  # Returns labeled subset
)
# Only compute embeddings for labeled nodes (1.8M instead of 111M)
labeled_indices = torch.cat([train_idx, valid_idx, test_idx])
z, inference_time = inference_embeddings(
    model, data, conf, logger, device,
    labeled_indices=labeled_indices  # Huge memory savings!
)
# Use labeled subgraph for structure metrics
struct_results = structure_metrics(
    labeled_subgraph['edge_index'],  # Much smaller graph
    pred,
    metrics=['Mod', 'Cond']
)
Model-Specific Optimizations
MAGI: Custom data loader for random walk sampling
from pyagc.models.magi import MAGINeighborLoader
train_loader = MAGINeighborLoader(
    data,
    num_neighbors=[10] * 2,
    num_walks=20,  # Random walks per node
    walk_length=4,  # Steps per walk
    batch_size=2048
)
DAEGC: Multi-stage training with separate configurations
pretrain_epochs: 200
pretrain_batch_size: 2048
finetune_epochs: 100
finetune_batch_size: 4096 # Can be larger in finetuning
Practical Example: Scaling to Reddit
Complete workflow for the Reddit dataset (233K nodes, 23M edges):
Configuration File
# train.conf.yaml for Reddit
dataset: Reddit
# Model
gnn_type: sage
hidden_channels: 256
num_layers: 2
dropout: 0.0
# Training
mini_batch: true
batch_size: 2048
fan_out: 10
epochs: 200
lr: 0.001
# Inference
infer_batch_size: 8192
infer_fan_out: -1
force_full_batch_inference: false
allow_cpu_fallback: true
# Checkpointing
save_every: 20
Training Script
import logging
import torch
from pyagc.data import get_dataset
from pyagc.models import DGI
from pyagc.encoders import create_tuned_gnn
from pyagc.utils import get_training_config, CheckpointManager
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
# inference_embeddings is the helper shown in the inference section;
# import it from wherever your PyAGC setup defines it
logger = logging.getLogger(__name__)
# Load configuration
conf = get_training_config('Reddit', config_path='train.conf.yaml')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load dataset
x, edge_index, y = get_dataset('Reddit', root='data')
data = Data(x=x, edge_index=edge_index)
print(f"Graph: {data.num_nodes:,} nodes, {data.num_edges:,} edges")
# Create model
encoder = create_tuned_gnn(
    gnn_type='sage',
    in_channels=data.num_features,
    hidden_channels=256,
    num_layers=2
)
model = DGI(hidden_channels=256, encoder=encoder).to(device)
# Mini-batch loader matching the configuration above
train_loader = NeighborLoader(
    data,
    input_nodes=None,
    num_neighbors=[10] * 2,  # fan_out=10 for each of the 2 layers
    batch_size=2048,
    shuffle=True
)
# Setup checkpointing
ckpt_manager = CheckpointManager('ckpts/reddit', 'seed0', logger)
# Training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(1, 201):
    # One pass over all mini-batches
    loss = model.train_batch(train_loader, optimizer, epoch)
    if epoch % 20 == 0:
        ckpt_manager.save_checkpoint(model, optimizer, epoch, loss)
# Inference with automatic fallback
z, inference_time = inference_embeddings(model, data, conf, logger, device)
# Clustering
from pyagc.clusters import KMeansClusterHead
kmeans = KMeansClusterHead(n_clusters=41)  # Reddit has 41 classes
pred = kmeans.fit_predict(z)
Expected Output
============================================================
System Information
============================================================
Dataset: Data(x=[232965, 602], edge_index=[2, 114615892])
Nodes: 232,965
Edges: 114,615,892
Training mode: Mini-batch
GNN type: SAGE
Total parameters: 1.234M
Device: cuda:0
GPU: NVIDIA A100-SXM4-40GB
============================================================
Training Mode: Mini-batch (NeighborLoader)
============================================================
Batch size: 2048
Fan-out: [10] × 2 layers
Epoch 001: 100%|██████████| 232965/232965 [00:42<00:00]
Epoch: 001 Loss: 0.8234
...
Training completed in 142.34s (avg 711ms/epoch)
============================================================
Inference Stage
============================================================
Attempting full-batch inference on cuda:0...
Full-batch inference failed on cuda:0 due to OOM
Using mini-batch inference on cuda:0...
Inference batch size: 8192
✓ Mini-batch inference completed in 8.21s
============================================================
Clustering Stage
============================================================
Run 1/5: NMI=45.23, ARI=42.18, ACC=68.91 [3.12s]
...
Final: NMI=45.67±0.31, ARI=42.45±0.28, ACC=69.12±0.34
Performance Comparison
Performance on different graph sizes:
| Dataset | Nodes | Mode | Training Time | Peak Memory | NMI |
|---|---|---|---|---|---|
| Cora | 2.7K | Full-batch | 12s | 0.2 GB | 68.5 |
| Cora | 2.7K | Mini-batch | 18s | 0.1 GB | 68.2 |
| Reddit | 233K | Full-batch | OOM | - | - |
| Reddit | 233K | Mini-batch | 142s | 4.2 GB | 45.7 |
| Products | 2.4M | Full-batch | OOM | - | - |
| Products | 2.4M | Mini-batch | 8.3min | 12 GB | 52.3 |
Best Practices
Choosing Batch Size
# Rule of thumb for batch_size:
# - Small graphs (< 10K nodes): Use full-batch
# - Medium graphs (10K - 1M nodes): batch_size = 1024-4096
# - Large graphs (> 1M nodes): batch_size = 4096-8192
if data.num_nodes < 10000:
    conf['mini_batch'] = False
elif data.num_nodes < 1000000:
    conf['batch_size'] = 2048
else:
    conf['batch_size'] = 8192
Choosing Fan-out
# Trade-off between accuracy and speed:
# - High fan-out (20-50): Better accuracy, slower
# - Medium fan-out (10-15): Balanced
# - Low fan-out (5-10): Faster, may lose accuracy
# Deeper GNNs need smaller fan-out to control batch size
if num_layers == 2:
    fan_out = 15
elif num_layers == 3:
    fan_out = 10
else:
    fan_out = 5
GNN Architecture Selection
# Best GNN types for mini-batch training:
# 1. GraphSAGE (sampling-based, designed for mini-batch)
# 2. GAT (attention weights computed locally)
# Avoid for mini-batch:
# - GCN with cached=True (expects full graph)
# - Models requiring global information
Troubleshooting
Common Issues and Solutions
Issue 1: Out of Memory During Training
RuntimeError: CUDA out of memory. Tried to allocate 2.34 GiB
Solution: Reduce batch size or fan-out
# Original configuration
batch_size: 4096
fan_out: 15
# Reduced configuration
batch_size: 1024 # Reduce by 4x
fan_out: 10 # Reduce neighbors
Issue 2: Out of Memory During Inference
RuntimeError: CUDA out of memory during inference
Solution: Enable automatic fallback and reduce inference batch size
force_full_batch_inference: false
allow_cpu_fallback: true
infer_batch_size: 2048 # Smaller than training batch_size
infer_fan_out: -1 # Keep full neighborhood for accuracy
Issue 3: Training is Too Slow
# Epoch takes 10+ minutes on medium-sized graph
Solution: Increase batch size and use multiple workers
batch_size: 8192 # Larger batches = fewer iterations
num_workers: 4 # Parallel data loading
fan_out: 10 # Reduce if still slow
Issue 4: Poor Clustering Performance in Mini-Batch Mode
# NMI drops from 68.5 (full-batch) to 45.2 (mini-batch)
Solution: Use full neighborhood during inference
# Training can use sampling
fan_out: 10
# Inference uses full neighborhood
infer_fan_out: -1
force_full_batch_inference: true # Try full-batch inference
Issue 5: DataLoader Deadlock with Multiple Workers
# Training hangs at first epoch with num_workers > 0
Solution: Set num_workers=0 or use proper multiprocessing setup
num_workers: 0 # Safe default, especially on CPU
Or use proper initialization:
if __name__ == '__main__':
    import torch.multiprocessing as mp
    mp.set_start_method('spawn', force=True)
    main()
Advanced Techniques
Gradient Accumulation for Larger Effective Batch Size
When GPU memory limits batch size but you want larger effective batches:
# Accumulate gradients over multiple mini-batches
accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps
for epoch in range(epochs):
    optimizer.zero_grad()
    total_loss = 0
    for i, batch in enumerate(train_loader):
        batch = batch.to(device)
        loss_output = model.loss_batch(batch)
        loss = loss_output.total / accumulation_steps
        loss.backward()
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        total_loss += loss.item() * accumulation_steps  # Undo scaling for logging
    # Apply gradients left over from an incomplete final group
    if (i + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()
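The reason for dividing each loss by `accumulation_steps` can be checked on a toy problem. The sketch below uses a one-parameter least-squares model (an assumption chosen for illustration, unrelated to any PyAGC model) and shows that accumulating scaled gradients over equal-sized micro-batches reproduces the gradient of one large batch. For objectives that couple the samples within a batch, as contrastive losses often do, accumulation is only an approximation of a larger batch.

```python
import math


def grad_mean_squared_error(w, samples):
    """d/dw of mean((w*x - y)^2) over the given (x, y) samples."""
    return sum(2 * x * (w * x - y) for x, y in samples) / len(samples)


w = 0.5
dataset = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0),
           (5.0, 7.0), (6.0, 6.0), (7.0, 9.0), (8.0, 8.0)]

# One large batch of 8 samples.
full_grad = grad_mean_squared_error(w, dataset)

# Four micro-batches of 2 samples, each gradient scaled by 1/accumulation_steps.
accumulation_steps = 4
micro = len(dataset) // accumulation_steps
accumulated = 0.0
for step in range(accumulation_steps):
    chunk = dataset[step * micro:(step + 1) * micro]
    accumulated += grad_mean_squared_error(w, chunk) / accumulation_steps

print(math.isclose(full_grad, accumulated))  # True
```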
Mixed Precision Training
Use automatic mixed precision (AMP) to reduce memory usage:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for epoch in range(epochs):
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # Forward pass with autocast
        with autocast():
            loss_output = model.loss_batch(batch)
            loss = loss_output.total
        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
This can reduce memory usage by ~30-40% with minimal accuracy loss.
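A back-of-the-envelope model shows where a figure in that range can come from. Under mixed precision the parameters and optimizer state stay in float32, so only activations shrink. The sketch below makes the simplifying assumption that every stored activation drops from 4 bytes to 2, which autocast does not guarantee for all operations; `amp_memory_saving` is a hypothetical helper.

```python
import math


def amp_memory_saving(activation_fraction):
    """Fraction of total memory saved if only activations shrink from 4 to 2 bytes."""
    return activation_fraction * (1 - 2 / 4)


for fraction in (0.6, 0.7, 0.8):
    saving = amp_memory_saving(fraction)
    print(f"activations = {fraction:.0%} of memory -> saving = {saving:.0%}")
```

Under this model, a saving of 30-40% corresponds to activations making up roughly 60-80% of peak memory, which is plausible for mini-batch GNN training where sampled neighborhoods dominate.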
Distributed Training (Multi-GPU)
For extremely large graphs, distribute training across multiple GPUs:
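import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch_geometric.loader import NeighborLoader
# Initialize distributed training
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
world_size = dist.get_world_size()
torch.cuda.set_device(local_rank)
device = torch.device(f'cuda:{local_rank}')
# Wrap model with DDP
model = DGI(hidden_channels=256, encoder=encoder).to(device)
model = DDP(model, device_ids=[local_rank])
# Distributed data loading: each rank trains on its own shard of seed nodes
seed_nodes = torch.arange(dist.get_rank(), data.num_nodes, world_size)
train_loader = NeighborLoader(
    data,
    input_nodes=seed_nodes,
    num_neighbors=[10] * 2,
    batch_size=2048 // world_size,  # Divide batch across GPUs
    shuffle=True
)
# Training loop remains the same
# Note: DDP synchronizes gradients only when the forward pass goes through
# the DDP wrapper, so check how train_batch invokes the model
for epoch in range(epochs):
    avg_loss = model.module.train_batch(train_loader, optimizer, epoch)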
Run with:
torchrun --nproc_per_node=4 train.py --dataset Reddit
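With one process per GPU, each rank should draw its seed nodes from a different part of the graph; otherwise every GPU repeats the same work. One simple way to do this, sketched here in plain Python with a hypothetical `shard_seed_nodes` helper, is a strided split in which the shards are disjoint and differ in size by at most one node.

```python
def shard_seed_nodes(num_nodes, world_size, rank):
    """Strided shard: rank r gets nodes r, r + world_size, r + 2 * world_size, ..."""
    return list(range(rank, num_nodes, world_size))


world_size = 4
shards = [shard_seed_nodes(10, world_size, rank) for rank in range(world_size)]
for rank, shard in enumerate(shards):
    print(rank, shard)
```

Each shard would then be passed as the seed-node set of that rank's loader. Because shard sizes can differ by one, ranks may see slightly different batch counts, which is worth checking when every rank must take part in each synchronization step.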
Dynamic Sampling Strategies
Adjust sampling strategy during training:
def get_adaptive_fan_out(epoch, max_epochs):
    """Increase fan-out as training progresses for better accuracy."""
    min_fan_out = 5
    max_fan_out = 15
    progress = epoch / max_epochs
    return int(min_fan_out + (max_fan_out - min_fan_out) * progress)

for epoch in range(1, epochs + 1):
    fan_out = get_adaptive_fan_out(epoch, epochs)
    train_loader = NeighborLoader(
        data,
        num_neighbors=[fan_out] * num_layers,
        batch_size=batch_size,
        shuffle=True
    )
    avg_loss = model.train_batch(train_loader, optimizer, epoch)
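Evaluating the schedule at a few epochs makes its shape concrete. The snippet restates the same linear ramp with the bounds as keyword arguments (an illustrative variation) and assumes a 200-epoch run.

```python
def get_adaptive_fan_out(epoch, max_epochs, min_fan_out=5, max_fan_out=15):
    """Linear ramp from min_fan_out to max_fan_out, truncated to an integer."""
    progress = epoch / max_epochs
    return int(min_fan_out + (max_fan_out - min_fan_out) * progress)


schedule = {epoch: get_adaptive_fan_out(epoch, 200) for epoch in (1, 50, 100, 150, 200)}
print(schedule)  # {1: 5, 50: 7, 100: 10, 150: 12, 200: 15}
```

Because `int()` truncates, the maximum fan-out is reached only in the final epoch; rounding instead would shift each step half an interval earlier.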
Subgraph-Based Evaluation
For billion-scale graphs like Papers100M, evaluate on manageable subgraphs:
# Extract labeled subgraph
labeled_indices = torch.cat([train_idx, valid_idx, test_idx])
# Get subgraph containing only labeled nodes and their edges
from torch_geometric.utils import subgraph
edge_index_sub, _ = subgraph(
    labeled_indices,
    edge_index,
    relabel_nodes=True,
    num_nodes=data.num_nodes
)
# Only compute embeddings for labeled nodes
z_labeled, _ = inference_embeddings(
    model, data, conf, logger, device,
    labeled_indices=labeled_indices
)
# Evaluate on subgraph
struct_results = structure_metrics(
    edge_index_sub,
    pred,
    metrics=['Mod', 'Cond']
)
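What subgraph extraction with relabeling does can be seen on a six-edge toy graph. The sketch below is plain Python rather than the PyTorch Geometric call, and `induced_subgraph` is a hypothetical helper: it keeps only edges whose endpoints are both labeled and renumbers the kept nodes from 0 according to their position in `keep_nodes`.

```python
def induced_subgraph(edges, keep_nodes):
    """Keep edges whose endpoints are both in keep_nodes and renumber from 0."""
    new_id = {old: new for new, old in enumerate(keep_nodes)}
    return [(new_id[u], new_id[v]) for u, v in edges
            if u in new_id and v in new_id]


edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0), (1, 3)]
labeled = [1, 3, 4]                      # hypothetical labeled node ids
print(induced_subgraph(edges, labeled))  # [(1, 2), (0, 1)]
```

After relabeling, node ids index into the labeled subset, so the rows of the embeddings and the predicted cluster labels must follow the same ordering as the node list used for extraction.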
Performance Profiling
Identify bottlenecks in your pipeline:
import torch.profiler as profiler
with profiler.profile(
    activities=[
        profiler.ProfilerActivity.CPU,
        profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    profile_memory=True,
) as prof:
    for epoch in range(1, 6):  # Profile first 5 epochs
        avg_loss = model.train_batch(train_loader, optimizer, epoch)
# Print profiling results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=10
))
# Export for visualization
prof.export_chrome_trace("trace.json")
# View at chrome://tracing
Example output:
--------------------------------- ------------ ------------ ------------
Name Self CPU % Self CUDA % CUDA total
--------------------------------- ------------ ------------ ------------
aten::matmul 15.2% 42.3% 1.234s
aten::index_select 8.4% 18.7% 543ms
Neighbor Sampling 12.1% 0.0% 352ms
--------------------------------- ------------ ------------ ------------
Real-World Case Studies
Case Study 1: Reddit (233K nodes)
Challenge: 23M edges cause OOM with full-batch training.
Solution:
gnn_type: sage
mini_batch: true
batch_size: 2048
fan_out: 10
num_layers: 2
infer_fan_out: -1 # Full neighborhood for inference
Results:
Training time: 142s (vs OOM with full-batch)
Peak memory: 4.2 GB (vs 32+ GB required for full-batch)
NMI: 45.7 (vs 46.2 with full-batch, a 0.5-point drop, about 1% relative)
Case Study 2: ogbn-products (2.4M nodes)
Challenge: Cannot fit even embeddings in GPU memory.
Solution:
gnn_type: sage
mini_batch: true
batch_size: 8192
fan_out: 10
num_layers: 2
# Use larger inference batch
infer_batch_size: 16384
force_full_batch_inference: false
allow_cpu_fallback: true
Results:
Training time: 8.3 min
Peak memory: 12 GB
NMI: 52.3
Inference: GPU mini-batch (CPU fallback not needed)
Case Study 3: ogbn-papers100M (111M nodes)
Challenge: Billion-scale graph, cannot store in single-GPU memory.
Solution:
gnn_type: sage
mini_batch: true
batch_size: 8192
fan_out: 5 # Lower fan-out for extreme scale
num_layers: 2
# Inference optimizations
infer_batch_size: 32768
infer_fan_out: 10 # Limited neighborhood for feasibility
Code:
# Load with labeled subset
x, edge_index, y, train_idx, valid_idx, test_idx, labeled_subgraph = \
    get_dataset('Papers100M', root='data', return_splits=True)
labeled_indices = torch.cat([train_idx, valid_idx, test_idx])
# Only 1.8M nodes (1.6% of total) need embeddings
# Train on full graph, infer on labeled subset
z, _ = inference_embeddings(
    model, data, conf, logger, device,
    labeled_indices=labeled_indices  # Huge memory savings
)
Results:
Training time: ~2 hours
Peak memory: 24 GB (single A100 GPU)
Evaluated on 1.8M labeled nodes instead of 111M
Structure metrics computed on labeled subgraph (23M edges vs 1.6B)
Summary and Recommendations
Quick Decision Guide
When to use full-batch:
Graph has < 10K nodes
GPU has sufficient memory (check with dataset_stats.py)
Need maximum accuracy
When to use mini-batch:
Graph has > 10K nodes
Getting OOM errors
Training on CPU
Want to scale to millions of nodes
Recommended configurations:
# Small graphs (< 10K nodes)
config = {
    'mini_batch': False,
    'cached': True
}
# Medium graphs (10K - 1M nodes)
config = {
    'mini_batch': True,
    'batch_size': 2048,
    'fan_out': 10,
    'infer_batch_size': 8192,
    'infer_fan_out': -1,
    'force_full_batch_inference': True
}
# Large graphs (1M - 10M nodes)
config = {
    'mini_batch': True,
    'batch_size': 4096,
    'fan_out': 10,
    'infer_batch_size': 16384,
    'infer_fan_out': 15,
    'force_full_batch_inference': False
}
# Extreme scale (> 10M nodes)
config = {
    'mini_batch': True,
    'batch_size': 8192,
    'fan_out': 5,
    'infer_batch_size': 32768,
    'infer_fan_out': 10,
    'allow_cpu_fallback': True
}
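The four tiers above can be folded into a single lookup keyed on node count. This is a convenience sketch, not a PyAGC API: `recommend_config` is a hypothetical helper, the values are copied from the recommended configurations, and the result is a starting point to tune rather than a guarantee that the run fits in memory.

```python
def recommend_config(num_nodes):
    """Pick a starting configuration from the node-count tiers listed above."""
    if num_nodes < 10_000:
        return {"mini_batch": False, "cached": True}
    if num_nodes < 1_000_000:
        return {"mini_batch": True, "batch_size": 2048, "fan_out": 10,
                "infer_batch_size": 8192, "infer_fan_out": -1,
                "force_full_batch_inference": True}
    if num_nodes < 10_000_000:
        return {"mini_batch": True, "batch_size": 4096, "fan_out": 10,
                "infer_batch_size": 16384, "infer_fan_out": 15,
                "force_full_batch_inference": False}
    return {"mini_batch": True, "batch_size": 8192, "fan_out": 5,
            "infer_batch_size": 32768, "infer_fan_out": 10,
            "allow_cpu_fallback": True}


print(recommend_config(232_965)["batch_size"])  # 2048
```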
Key Takeaways
PyAGC scales from graphs with thousands of nodes to graphs with billions of edges using neighbor sampling
Automatic fallback handles OOM gracefully without code changes
Separate inference config optimizes for speed and accuracy
Smart checkpointing enables resuming long training runs
Next Steps
Check the benchmark scripts for a billion-scale example
Review existing implementations in the models module