import math
from typing import Optional, Tuple, Literal, Union, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from torch_geometric.index import index2ptr
from torch_geometric.nn.inits import reset
from torch_geometric.typing import WITH_PYG_LIB, WITH_TORCH_CLUSTER
from torch_geometric.utils import degree, add_remaining_self_loops, sort_edge_index
from torch_geometric.utils.num_nodes import maybe_num_nodes
from pyagc.models.base import TrainableModel, LossOutput
def ppr_diffusion_weights(k: int, alpha: float = 0.2) -> Tensor:
r"""
Computes PPR (Personalized PageRank) diffusion weights.
.. math::
w_i = \alpha(1-\alpha)^i
Args:
k (int): Number of hops.
alpha (float, optional): Teleport probability. (default: :obj:`0.2`)
Returns:
Tensor of shape :obj:`(k+1,)` containing weights.
"""
weights = torch.tensor([alpha * ((1 - alpha) ** i) for i in range(k + 1)])
return weights
def heat_diffusion_weights(k: int, t: float = 5.0) -> Tensor:
r"""
Computes heat kernel diffusion weights.
.. math::
w_i = \frac{e^{-t} \cdot t^i}{i!}
Args:
k (int): Number of hops.
t (float, optional): Diffusion time. (default: :obj:`5.0`)
Returns:
Tensor of shape :obj:`(k+1,)` containing weights.
"""
weights = torch.tensor([
(math.exp(-t) * (t ** i)) / math.factorial(i)
for i in range(k + 1)
])
return weights
def compute_normalized_adjacency(
edge_index: Tensor,
num_nodes: int,
add_self_loops: bool = True
) -> Tensor:
r"""
Computes the symmetric normalized adjacency matrix in COO format.
.. math::
\tilde{A} = D^{-1/2} A D^{-1/2}
where :math:`A` is the adjacency matrix with optional self-loops.
Args:
edge_index (Tensor): Edge indices of shape :obj:`(2, num_edges)`.
num_nodes (int): Number of nodes in the graph.
add_self_loops (bool, optional): Whether to add self-loops.
(default: :obj:`True`)
Returns:
Sparse normalized adjacency matrix.
"""
if add_self_loops:
edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=num_nodes)
row, col = edge_index
deg = degree(row, num_nodes, dtype=torch.float)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return torch.sparse_coo_tensor(
edge_index,
edge_weight,
size=(num_nodes, num_nodes)
).to(edge_index.device)
def compute_diffusion_matrix(
normalized_adj: Tensor,
x: Tensor,
k: int = 2,
method: Literal['ppr', 'heat', 'custom'] = 'ppr',
coefs: Optional[Tensor] = None,
**kwargs
) -> Tensor:
r"""
Computes the :math:`k`-hop diffusion matrix:
.. math::
S_k X = \sum_{i=0}^{k} \alpha_i \tilde{A}^i X
where :math:`\tilde{A}` is the normalized adjacency matrix.
Args:
normalized_adj (Tensor): Sparse normalized adjacency matrix.
x (Tensor): Node features of shape :obj:`(num_nodes, num_features)`.
k (int, optional): Number of hops. (default: :obj:`2`)
method (str, optional): Diffusion method, one of :obj:`['ppr', 'heat', 'custom']`.
(default: :obj:`'ppr'`)
coefs (Tensor, optional): Custom diffusion coefficients of shape :obj:`(k+1,)`.
Required if method is :obj:`'custom'`. (default: :obj:`None`)
**kwargs: Additional arguments for PPR (alpha) or heat (t) methods.
Returns:
Diffusion features of shape :obj:`(num_nodes, num_features)`.
Example:
>>> # PPR diffusion
>>> SX = compute_diffusion_matrix(normalized_adj, x, k=2, method='ppr', alpha=0.2)
>>>
>>> # Heat diffusion
>>> SX = compute_diffusion_matrix(normalized_adj, x, k=3, method='heat', t=5.0)
>>>
>>> # Custom weights
>>> weights = torch.tensor([0.5, 0.3, 0.2])
>>> SX = compute_diffusion_matrix(normalized_adj, x, k=2, method='custom', coefs=weights)
"""
# Determine diffusion weights
if method == 'ppr':
alpha = kwargs.get('alpha', 0.2)
weights = ppr_diffusion_weights(k, alpha)
elif method == 'heat':
t = kwargs.get('t', 5.0)
weights = heat_diffusion_weights(k, t)
elif method == 'custom':
if coefs is None:
raise ValueError("Must provide 'coefs' when method='custom'")
weights = coefs
if weights.size(0) != k + 1:
raise ValueError(f"coefs must have length {k + 1}, got {weights.size(0)}")
else:
raise ValueError(f"Unknown diffusion method: {method}. Choose from ['ppr', 'heat', 'custom']")
weights = weights.to(x.device)
# Compute diffusion
result = weights[0] * x
current = x
for i in range(1, k + 1):
current = torch.sparse.mm(normalized_adj, current)
result = result + weights[i] * current
return result
class SmallS3GCEncoder(nn.Module):
r"""
Encoder for small-scale graphs (e.g., Cora).
Uses simple linear transformations without hidden layers.
Architecture:
:math:`\bar{X} = \tilde{A}X\Theta_1 + \S_kX\Theta_2 + I`
"""
def __init__(self, in_channels: int, hidden_channels: int, num_nodes: int):
super().__init__()
self.w1 = nn.Linear(in_channels, hidden_channels, bias=True)
self.w2 = nn.Linear(in_channels, hidden_channels, bias=True)
self.iden = nn.Parameter(torch.randn(num_nodes, hidden_channels, dtype=torch.float))
self.reset_parameters()
def reset_parameters(self):
self.w1.bias.data.fill_(0.0)
self.w2.bias.data.fill_(0.0)
nn.init.normal_(self.iden)
def forward(self, AX: Tensor, SX: Tensor, indices: Optional[Tensor] = None) -> Tensor:
if indices is not None:
AX = AX[indices]
SX = SX[indices]
iden = self.iden[indices]
else:
iden = self.iden
return F.normalize(
self.w1(AX) + self.w2(SX) + iden,
p=2, dim=1
)
class MediumS3GCEncoder(nn.Module):
r"""
Encoder for medium-scale graphs (e.g., ogbn-arxiv).
Uses 2-layer MLPs for better expressiveness.
Architecture:
:math:`\bar{X} = \text{PReLU}(W_2 \cdot \text{PReLU}(W_1 \tilde{A}X)) +
\text{PReLU}(W_4 \cdot \text{PReLU}(W_3 S_kX)) + I`
"""
def __init__(self, in_channels: int, hidden_channels: int, num_nodes: int):
super().__init__()
self.w1 = nn.Linear(in_channels, hidden_channels, bias=True)
self.w2 = nn.Linear(hidden_channels, hidden_channels, bias=True)
self.w3 = nn.Linear(in_channels, hidden_channels, bias=True)
self.w4 = nn.Linear(hidden_channels, hidden_channels, bias=True)
self.prelu1 = nn.PReLU(hidden_channels)
self.prelu2 = nn.PReLU(hidden_channels)
self.iden = nn.Parameter(torch.randn(num_nodes, hidden_channels, dtype=torch.float))
self.reset_parameters()
def reset_parameters(self):
self.w1.bias.data.fill_(0.0)
self.w2.bias.data.fill_(0.0)
self.w3.bias.data.fill_(0.0)
self.w4.bias.data.fill_(0.0)
nn.init.normal_(self.iden)
def forward(self, AX: Tensor, SX: Tensor, indices: Optional[Tensor] = None) -> Tensor:
if indices is not None:
AX = AX[indices]
SX = SX[indices]
iden = self.iden[indices]
else:
iden = self.iden
return F.normalize(
self.w2(self.prelu1(self.w1(AX))) +
self.w4(self.prelu2(self.w3(SX))) +
iden,
p=2, dim=1
)
class LargeS3GCEncoder(nn.Module):
r"""
Encoder for large-scale graphs (e.g., ogbn-papers100M).
Uses 2-layer MLPs for better expressiveness.
Unlike Medium encoder, this doesn't store node embeddings internally -
they are managed externally by the S3GC model for memory efficiency.
Architecture:
:math:`\bar{X} = \text{PReLU}(W_2 \cdot \text{PReLU}(W_1 \tilde{A}X)) +
\text{PReLU}(W_4 \cdot \text{PReLU}(W_3 S_kX)) + I`
where I (identity embeddings) are passed as an argument rather than stored.
"""
def __init__(self, in_channels: int, hidden_channels: int):
super().__init__()
self.w1 = nn.Linear(in_channels, hidden_channels, bias=True)
self.w2 = nn.Linear(hidden_channels, hidden_channels, bias=True)
self.w3 = nn.Linear(in_channels, hidden_channels, bias=True)
self.w4 = nn.Linear(hidden_channels, hidden_channels, bias=True)
self.prelu1 = nn.PReLU(hidden_channels)
self.prelu2 = nn.PReLU(hidden_channels)
self.reset_parameters()
def reset_parameters(self):
self.w1.bias.data.fill_(0.0)
self.w2.bias.data.fill_(0.0)
self.w3.bias.data.fill_(0.0)
self.w4.bias.data.fill_(0.0)
def forward(self, AX: Tensor, SX: Tensor, iden: Tensor) -> Tensor:
r"""
Forward pass with explicit iden parameter.
Args:
AX (Tensor): Pre-multiplied features :math:`\tilde{A}X` of shape
:obj:`(batch_size, in_channels)`.
SX (Tensor): Diffusion features :math:`S_kX` of shape
:obj:`(batch_size, in_channels)`.
iden (Tensor): Learnable node embeddings of shape
:obj:`(batch_size, hidden_channels)`.
Returns:
Normalized embeddings of shape :obj:`(batch_size, hidden_channels)`.
"""
return F.normalize(
self.w2(self.prelu1(self.w1(AX))) +
self.w4(self.prelu2(self.w3(SX))) +
iden,
p=2, dim=1
)
[docs]class S3GC(TrainableModel):
r"""
The S3GC (Scalable Self-Supervised Graph Clustering) model from the
`"S3GC: Scalable Self-Supervised Graph Clustering"
<https://openreview.net/forum?id=ldl2V3vLZ5>`_ paper (Devvrit et al., NeurIPS 2022).
S3GC uses a simple GCN-based encoder combined with contrastive learning
to learn clusterable node representations. The architecture adapts based
on graph scale:
- **Small** (< 50K nodes): Simple linear encoder
- **Medium** (50K - 3M nodes): 2-layer MLP encoder
- **Large** (> 3M nodes): 2-layer MLP with memory-efficient embeddings
The encoder combines:
1. Direct attribute transformation: :math:`f(\tilde{A}X)`
2. Diffusion-based transformation: :math:`g(S_kX)`
3. Learnable node embeddings: :math:`I`
Training uses SimCLR-style contrastive loss where nodes sampled via
random walks are positives, and randomly sampled nodes are negatives.
Args:
edge_index (Tensor): Edge indices of the graph.
num_nodes (int): Number of nodes in the graph.
in_channels (int): Input feature dimension.
hidden_channels (int): Output embedding dimension.
walk_length (int): Length of random walks for positive sampling.
context_size (int): Context size for random walk (should be :math:`\leq` walk_length).
walks_per_node (int, optional): Number of random walks per node. (default: :obj:`1`)
num_negative_samples (int, optional): Number of negative samples per positive.
(default: :obj:`1`)
p (float, optional): Return parameter for random walk. (default: :obj:`1.0`)
q (float, optional): In-out parameter for random walk. (default: :obj:`1.0`)
scale (str, optional): Graph scale, one of :obj:`['small', 'medium', 'large', 'auto']`.
If :obj:`'auto'`, automatically determined by num_nodes. (default: :obj:`'auto'`)
Example:
>>> from pyagc.models.s3gc import S3GC, precompute_features
>>> from pyagc.data import get_dataset
>>>
>>> # Load data
>>> x, edge_index, y = get_dataset('Cora', root='./data')
>>>
>>> # Precompute features
>>> AX, SX = precompute_features(x, edge_index, x.size(0), method='ppr')
>>>
>>> # Create model (automatically detects 'small' scale)
>>> model = S3GC(
... edge_index=edge_index,
... num_nodes=x.size(0),
... in_channels=x.size(1),
... hidden_channels=256,
... walk_length=3,
... context_size=3
... )
>>>
>>> # Set precomputed features
>>> model.set_precomputed_features(AX, SX)
>>>
>>> # Train
>>> loader = model.loader(batch_size=2708, shuffle=True)
>>> optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
>>> for epoch in range(1, 201):
... loss = model.train_epoch(loader, optimizer, epoch)
"""
def __init__(
self,
edge_index: Tensor,
num_nodes: int,
in_channels: int,
hidden_channels: int,
walk_length: int,
context_size: int,
walks_per_node: int = 1,
num_negative_samples: int = 1,
p: float = 1.0,
q: float = 1.0,
scale: Literal['small', 'medium', 'large', 'auto'] = 'auto',
keep_embeddings_on_cpu: bool = False,
):
super().__init__()
if WITH_PYG_LIB and p == 1.0 and q == 1.0:
self.random_walk_fn = torch.ops.pyg.random_walk
elif WITH_TORCH_CLUSTER:
self.random_walk_fn = torch.ops.torch_cluster.random_walk
else:
if p == 1.0 and q == 1.0:
raise ImportError(f"'{self.__class__.__name__}' "
f"requires either the 'pyg-lib' or "
f"'torch-cluster' package")
else:
raise ImportError(f"'{self.__class__.__name__}' "
f"requires the 'torch-cluster' package")
if walk_length < context_size:
raise ValueError(
f"walk_length ({walk_length}) must be >= context_size ({context_size})"
)
self.num_nodes = maybe_num_nodes(edge_index, num_nodes)
# Determine scale
if scale == 'auto':
if num_nodes < 50_000:
scale = 'small'
elif num_nodes < 3_000_000:
scale = 'medium'
else:
scale = 'large'
self.scale = scale
self.keep_embeddings_on_cpu = keep_embeddings_on_cpu
# Create appropriate encoder
if scale == 'small':
self.encoder = SmallS3GCEncoder(in_channels, hidden_channels, num_nodes)
elif scale == 'medium':
self.encoder = MediumS3GCEncoder(in_channels, hidden_channels, num_nodes)
else: # large
self.encoder = LargeS3GCEncoder(in_channels, hidden_channels)
# Embeddings are managed separately for memory efficiency
self.iden_embedding = nn.Embedding(num_nodes, hidden_channels, sparse=True)
# Convert to CSR format for efficient random walk
row, col = sort_edge_index(edge_index, num_nodes=self.num_nodes).cpu()
self.rowptr, self.col = index2ptr(row, self.num_nodes), col
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.walk_length = walk_length - 1 # Adjusted for implementation
self.context_size = context_size
self.walks_per_node = walks_per_node
self.num_negative_samples = num_negative_samples
self.p = p
self.q = q
# Store precomputed matrices (will be set externally)
self.register_buffer('_AX', torch.empty(0), persistent=False)
self.register_buffer('_SX', torch.empty(0), persistent=False)
self._features_on_cpu = False
self.reset_parameters()
[docs] def reset_parameters(self):
r"""Resets all learnable parameters."""
reset(self.encoder)
if self.scale == 'large':
nn.init.normal_(self.iden_embedding.weight)
[docs] def set_precomputed_features(self, AX: Tensor, SX: Tensor, keep_on_cpu: bool = False):
r"""
Sets precomputed :math:`\tilde{A}X` and :math:`S_kX` matrices.
This should be called before training to avoid recomputing these
matrices in every forward pass.
Args:
AX (Tensor): Precomputed :math:`\tilde{A}X` of shape :obj:`(num_nodes, in_channels)`.
SX (Tensor): Precomputed :math:`S_kX` of shape :obj:`(num_nodes, in_channels)`.
keep_on_cpu (bool): If True, keep features on CPU to save GPU memory
"""
if keep_on_cpu:
# Keep on CPU, will transfer batches during training
self._AX = AX.cpu()
self._SX = SX.cpu()
self._features_on_cpu = True
else:
device = next(self.encoder.parameters()).device
self._AX = AX.to(device)
self._SX = SX.to(device)
self._features_on_cpu = False
[docs] def embed(self, indices: Optional[Tensor] = None) -> Tensor:
r"""
Returns node embeddings.
Args:
indices (Tensor, optional): Node indices. If :obj:`None`, returns
embeddings for all nodes. (default: :obj:`None`)
Returns:
Node embeddings of shape :obj:`(num_nodes, hidden_channels)` or
:obj:`(len(indices), hidden_channels)`.
"""
if self._AX.numel() == 0 or self._SX.numel() == 0:
raise RuntimeError(
"Must call `set_precomputed_features()` before `embed()`."
)
if self.scale == 'large':
# For large graphs, use embedding layer
if indices is None:
indices = torch.arange(self.num_nodes, device=self._AX.device)
# Ensure indices are on correct device
if self._features_on_cpu:
indices_cpu = indices.cpu() if indices.device.type != 'cpu' else indices
AX = self._AX[indices_cpu]
SX = self._SX[indices_cpu]
# Move features to encoder device
encoder_device = next(self.encoder.parameters()).device
AX = AX.to(encoder_device)
SX = SX.to(encoder_device)
else:
AX = self._AX[indices]
SX = self._SX[indices]
# Handle embeddings - load to GPU on demand
if self.keep_embeddings_on_cpu:
# Embeddings are on CPU, load batch to GPU
indices_cpu = indices.cpu() if indices.device.type != 'cpu' else indices
iden = self.iden_embedding(indices_cpu)
# Move to same device as encoder
encoder_device = next(self.encoder.parameters()).device
iden = iden.to(encoder_device)
else:
# Embeddings are on GPU
indices_embed = indices.to(self.iden_embedding.weight.device)
iden = self.iden_embedding(indices_embed)
return self.encoder(AX, SX, iden)
else:
# For small/medium graphs, iden is part of encoder
if self._features_on_cpu and indices is not None:
device = next(self.encoder.parameters()).device
AX = self._AX[indices].to(device)
SX = self._SX[indices].to(device)
return self.encoder(AX, SX, indices)
else:
return self.encoder(self._AX, self._SX, indices)
[docs] @torch.jit.export
def pos_sample(self, batch: Tensor) -> Tensor:
r"""Samples positive nodes via biased random walks."""
batch = batch.repeat(self.walks_per_node)
rw = self.random_walk_fn(
self.rowptr.cpu(),
self.col.cpu(),
batch.cpu(),
self.walk_length,
self.p,
self.q
)
if not isinstance(rw, Tensor):
rw = rw[0]
walks = []
num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
for j in range(num_walks_per_rw):
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)
[docs] @torch.jit.export
def neg_sample(self, batch: Tensor) -> Tensor:
r"""Samples negative nodes randomly or cluster-aware."""
batch = batch.repeat(self.walks_per_node * self.num_negative_samples)
# Random negative sampling
rw = torch.randint(
self.num_nodes,
(batch.size(0), self.walk_length * self.num_negative_samples),
device=batch.device
)
rw = torch.cat([batch.view(-1, 1), rw], dim=-1)
walks = []
num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
for j in range(num_walks_per_rw):
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)
[docs] @torch.jit.export
def sample(self, batch: Union[List[int], Tensor]) -> Tuple[Tensor, Tensor]:
r"""Samples positive and negative random walks for a batch."""
if not isinstance(batch, Tensor):
batch = torch.tensor(batch)
return self.pos_sample(batch), self.neg_sample(batch)
[docs] def loader(self, **kwargs) -> DataLoader:
r"""
Creates a DataLoader for S3GC training.
Returns:
DataLoader that samples nodes and generates positive/negative walks.
"""
return DataLoader(range(self.num_nodes), collate_fn=self.sample, **kwargs)
[docs] def loss(self, pos_rw: Tensor, neg_rw: Tensor, node_mapping: Tensor) -> LossOutput:
r"""
Computes the SimCLR-style contrastive loss.
Args:
pos_rw (Tensor): Positive random walks of shape :obj:`(num_walks, context_size)`.
neg_rw (Tensor): Negative random walks of shape :obj:`(num_walks, context_size)`.
node_mapping (Tensor): Mapping from global node IDs to batch-local IDs.
Returns:
LossOutput containing total loss and individual components.
"""
# Map global node IDs to batch-local IDs
pos_rw_mapped = F.embedding(pos_rw.view(-1), node_mapping.view(-1, 1)).view(pos_rw.size())
neg_rw_mapped = F.embedding(neg_rw.view(-1), node_mapping.view(-1, 1)).view(neg_rw.size())
# Get unique nodes and compute their embeddings
unique = torch.unique(torch.cat((pos_rw, neg_rw), dim=-1))
# For CPU embeddings, need to handle device transfer
if self.scale == 'large' and self.keep_embeddings_on_cpu:
# Keep unique on original device for consistency
device = unique.device
embeddings = self.embed(indices=unique) # This will handle CPU->GPU transfer
embeddings = embeddings.to(device) # Ensure on correct device
else:
embeddings = self.embed(indices=unique)
# Rest of loss computation remains the same
start_pos, rest_pos = pos_rw_mapped[:, 0], pos_rw_mapped[:, 1:].contiguous()
h_start_pos = F.embedding(start_pos, embeddings).view(
pos_rw_mapped.size(0), 1, self.hidden_channels
)
h_rest_pos = F.embedding(rest_pos.view(-1), embeddings).view(
pos_rw_mapped.size(0), -1, self.hidden_channels
)
out_pos = (h_start_pos * h_rest_pos).sum(dim=-1)
pos_loss = torch.logsumexp(out_pos, dim=-1)
start_neg, rest_neg = neg_rw_mapped[:, 0], neg_rw_mapped[:, 1:].contiguous()
h_start_neg = F.embedding(start_neg, embeddings).view(
neg_rw_mapped.size(0), 1, self.hidden_channels
)
h_rest_neg = F.embedding(rest_neg.view(-1), embeddings).view(
neg_rw_mapped.size(0), -1, self.hidden_channels
)
out_neg = (h_start_neg * h_rest_neg).sum(dim=-1)
neg_loss = torch.logsumexp(out_neg, dim=-1)
neg_loss = torch.logsumexp(
torch.cat((neg_loss.view(-1, 1), pos_loss.view(-1, 1)), dim=-1),
dim=-1
)
total = -torch.mean(torch.exp(pos_loss - neg_loss))
return LossOutput(
total=total,
components={
'pos': pos_loss.mean().item(),
'neg': neg_loss.mean().item()
}
)
[docs] def train_epoch(
self,
loader: DataLoader,
optimizer: torch.optim.Optimizer,
epoch: int,
verbose: bool = True
) -> float:
r"""
Runs one epoch of S3GC training using the custom DataLoader.
Args:
loader (DataLoader): S3GC data loader created via :meth:`loader`.
optimizer (torch.optim.Optimizer): The optimizer.
epoch (int): Current epoch number.
verbose (bool, optional): If :obj:`True`, prints training progress.
(default: :obj:`True`)
Returns:
Average loss value of the epoch.
"""
self.train()
if self._AX.numel() == 0 or self._SX.numel() == 0:
raise RuntimeError(
"Must call `set_precomputed_features()` before training."
)
device = next(self.parameters()).device
mapping = torch.zeros(self.num_nodes, dtype=torch.long, device=device)
total_loss = 0.0
total_pos = 0.0
total_neg = 0.0
num_batches = 0
if verbose:
from tqdm import tqdm
pbar = tqdm(total=len(loader))
pbar.set_description(f'Epoch {epoch:02d}')
for pos_rw, neg_rw in loader:
optimizer.zero_grad()
pos_rw = pos_rw.to(device)
neg_rw = neg_rw.to(device)
# Get unique nodes in this batch
unique = torch.unique(torch.cat((pos_rw, neg_rw), dim=-1))
mapping.scatter_(0, unique, torch.arange(unique.size(0), device=device))
loss_output = self.loss(pos_rw, neg_rw, mapping)
loss_output.total.backward()
optimizer.step()
total_loss += loss_output.total.item()
total_pos += loss_output.components['pos']
total_neg += loss_output.components['neg']
num_batches += 1
if verbose:
pbar.update(1)
if verbose:
pbar.close()
avg_loss = total_loss / num_batches
avg_pos = total_pos / num_batches
avg_neg = total_neg / num_batches
if verbose:
print(f"Epoch: {epoch:02d} Loss: {avg_loss:.4f}, POS: {avg_pos:.4f}, NEG: {avg_neg:.4f}")
return avg_loss
def __repr__(self) -> str:
return (
f'{self.__class__.__name__}('
f'num_nodes={self.num_nodes}, '
f'scale={self.scale}, '
f'in_channels={self.in_channels}, '
f'hidden_channels={self.hidden_channels})'
)
def precompute_features(
x: Tensor,
edge_index: Tensor,
num_nodes: int,
method: Literal['ppr', 'heat', 'custom'] = 'ppr',
k_hop: int = 2,
coefs: Optional[Tensor] = None,
add_self_loops: bool = True,
verbose: bool = True,
**kwargs
) -> Tuple[Tensor, Tensor]:
r"""
Precomputes :math:`\tilde{A}X` and :math:`S_kX` for S3GC.
This function should be called before training to avoid recomputing
these matrices in every forward pass.
Args:
x (Tensor): Node features of shape :obj:`(num_nodes, num_features)`.
edge_index (Tensor): Edge indices of shape :obj:`(2, num_edges)`.
num_nodes (int): Number of nodes in the graph.
method (str, optional): Diffusion method, one of :obj:`['ppr', 'heat', 'custom']`.
(default: :obj:`'ppr'`)
k_hop (int, optional): Number of hops for diffusion. (default: :obj:`2`)
coefs (Tensor, optional): Custom diffusion coefficients for :obj:`method='custom'`.
(default: :obj:`None`)
add_self_loops (bool, optional): Whether to add self-loops. (default: :obj:`True`)
verbose (bool, optional): Whether to print progress. (default: :obj:`True`)
**kwargs: Additional arguments for specific diffusion methods:
- :obj:`alpha` (float): PPR teleport probability (default: 0.2)
- :obj:`t` (float): Heat diffusion time (default: 5.0)
Returns:
Tuple of (:math:`\tilde{A}X`, :math:`S_kX`), both of shape
:obj:`(num_nodes, num_features)`.
Example:
>>> from pyagc.models.s3gc import precompute_features
>>> from pyagc.data import get_dataset
>>>
>>> # Load data
>>> x, edge_index, y = get_dataset('Cora', root='./data')
>>>
>>> # PPR diffusion (default)
>>> AX, SX = precompute_features(x, edge_index, x.size(0), method='ppr', k_hop=2)
>>>
>>> # Heat diffusion
>>> AX, SX = precompute_features(x, edge_index, x.size(0), method='heat', k_hop=3, t=5.0)
>>>
>>> # Custom weights
>>> weights = torch.tensor([0.5, 0.3, 0.2])
>>> AX, SX = precompute_features(x, edge_index, x.size(0), method='custom',
... k_hop=2, alpha=weights)
"""
if verbose:
print(f"Precomputing features using {method} diffusion with k={k_hop}...")
# Compute normalized adjacency
if verbose:
print("Computing normalized adjacency matrix...")
normalized_adj = compute_normalized_adjacency(
edge_index, num_nodes, add_self_loops
)
# Compute AX
if verbose:
print("Computing AX...")
AX = torch.sparse.mm(normalized_adj, x)
# Compute diffusion matrix SX
if verbose:
print(f"Computing diffusion matrix SX with {method} method...")
SX = compute_diffusion_matrix(
normalized_adj, x,
k=k_hop,
method=method,
coefs=coefs,
add_self_loops=add_self_loops,
**kwargs
)
if verbose:
print("Feature precomputation completed!")
return AX, SX
class CompositeOptimizer:
r"""
A wrapper optimizer that combines multiple optimizers into one.
It exposes a unified interface (`zero_grad`, `step`, `state_dict`,
`load_state_dict`) so it can be used exactly like a single PyTorch
optimizer. It does not inherit from `Optimizer` to keep the
implementation lightweight and flexible.
The internal optimizers can be of any type (Adam, SparseAdam, SGD, ...).
Example:
>>> sparse_opt = torch.optim.SparseAdam(...)
>>> dense_opt = torch.optim.Adam(...)
>>> optimizer = CompositeOptimizer(sparse=sparse_opt, dense=dense_opt)
"""
def __init__(self, **optimizers: torch.optim.Optimizer):
r"""
Args:
**optimizers: Arbitrary number of optimizers passed as
keyword arguments. The key names will be preserved
in `state_dict()` for saving/loading.
"""
self.optimizers = optimizers # dict: name -> optimizer
def zero_grad(self):
r"""Clears gradients of all wrapped optimizers."""
for opt in self.optimizers.values():
opt.zero_grad()
def step(self):
r"""Performs a step for each wrapped optimizer."""
for opt in self.optimizers.values():
opt.step()
def state_dict(self):
r"""
Returns a state_dict containing the state of all
wrapped optimizers, keyed by the names provided at initialization.
"""
return {name: opt.state_dict() for name, opt in self.optimizers.items()}
def load_state_dict(self, state_dict):
r"""
Loads the state_dict for each wrapped optimizer.
Args:
state_dict (dict): A state dictionary produced by `state_dict()`.
"""
for name, opt in self.optimizers.items():
if name not in state_dict:
raise KeyError(f"Missing optimizer state for key '{name}'")
opt.load_state_dict(state_dict[name])