from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Union
import torch
from torch import nn, Tensor
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
@dataclass
class LossOutput:
    r"""
    Container pairing a training loss with its named parts.

    The :obj:`total` tensor is the value to backpropagate through, while
    :obj:`components` holds plain floats that are only formatted for logging.

    Args:
        total (Tensor): The total loss scalar for backpropagation.
        components (Dict[str, float]): Individual loss terms keyed by name,
            e.g., :obj:`{'reconstruction': 0.5, 'kl': 0.3}`.

    Example:
        >>> loss_output = LossOutput(
        ...     total=torch.tensor(1.5),
        ...     components={'ali': 0.8, 'nei': 0.5, 'spa': 0.2}
        ... )
        >>> print(loss_output.log_string("Epoch 01: "))
        Epoch 01: Loss: 1.5000, ALI: 0.8000, NEI: 0.5000, SPA: 0.2000
    """
    total: Tensor
    components: Dict[str, float]

    def __float__(self) -> float:
        r"""Enables :obj:`float(loss_output)`; yields :obj:`total` as a Python float."""
        scalar = self.total.item()
        return float(scalar)

    def log_string(self, prefix: str = "") -> str:
        r"""
        Builds a single-line, comma-separated summary for printing.

        The total comes first, followed by every component in dictionary
        order with its name upper-cased. All values use four decimals.

        Args:
            prefix (str, optional): Text prepended verbatim. (default: :obj:`""`)

        Returns:
            Formatted string like :obj:`"Loss: 1.5000, ALI: 0.8000, NEI: 0.5000"`.
        """
        component_fields = [
            f"{key.upper()}: {val:.4f}" for key, val in self.components.items()
        ]
        summary = ", ".join([f"Loss: {self.total.item():.4f}", *component_fields])
        return f"{prefix}{summary}"
class BaseModel(ABC, nn.Module):
    r"""
    Base interface for all PyAGC models.

    All models must implement :meth:`embed` to produce node embeddings.
    For trainable models, inherit from :class:`TrainableModel` instead.

    Example:
        >>> class MyEncoder(BaseModel):
        ...     def embed(self, x, edge_index):
        ...         return self.encoder(x, edge_index)

    See Also:
        - :class:`TrainableModel`: For models with loss computation
        - :class:`ClusteringModel`: For end-to-end clustering
    """

    @abstractmethod
    def embed(self, *args: Any, **kwargs: Any) -> Tensor:
        r"""
        Returns node embeddings.

        Args:
            For graph-based models: x, edge_index, ...
            For lookup-based models: batch (node indices)

        The output shape and type depend on the specific model implementation.
        Typically returns a :obj:`Tensor` of shape :obj:`(num_nodes, hidden_dim)`.

        Note:
            :meth:`infer_full` and :meth:`infer_batch` call this method with
            the data object unpacked as keyword arguments, so implementations
            used with them must accept the object's keys by name.
        """
        pass

    def reset_parameters(self):
        r"""Resets learnable parameters. Override when needed."""
        pass

    @torch.no_grad()
    def infer_full(self, data: Data) -> Any:
        r"""
        Full-graph inference: returns embeddings or predictions for all nodes.

        Switches the module to eval mode and calls :meth:`embed` with
        :obj:`data` unpacked as keyword arguments, with gradients disabled.

        Args:
            data (Data): Input graph data.

        Returns:
            Node embeddings or predictions, typically of shape :obj:`(num_nodes, *)`.
        """
        self.eval()
        return self.embed(**data)

    @torch.no_grad()
    def infer_batch(self, loader: NeighborLoader, verbose: bool = True) -> Any:
        r"""
        Mini-batch inference over a NeighborLoader.

        For node-level outputs, concatenates only the seed nodes of each batch,
        i.e. the first :obj:`batch.batch_size` rows of every :meth:`embed`
        output. Per-batch results are collected on CPU and the concatenated
        result is moved back to the model's device.

        Args:
            loader (NeighborLoader): Mini-batch data loader.
            verbose (bool, optional): If :obj:`True`, displays a progress bar
                whose total is :obj:`loader.data.num_nodes`.
                (default: :obj:`True`)

        Returns:
            Node embeddings or predictions for all nodes in the loader.
        """
        self.eval()

        # Peek at the first parameter only; a parameter-free model runs on CPU.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

        pbar = None
        if verbose:
            from tqdm import tqdm
            pbar = tqdm(total=loader.data.num_nodes)
            pbar.set_description("Inference stage")

        all_z = []
        try:
            for batch in loader:
                batch = batch.to(device)
                z = self.embed(**batch)
                # Keep only the leading seed-node rows; stage them on CPU
                all_z.append(z[:batch.batch_size].cpu())
                if pbar is not None:
                    pbar.update(batch.batch_size)
        finally:
            # Release the progress bar even if a batch fails mid-epoch
            if pbar is not None:
                pbar.close()

        return torch.cat(all_z, dim=0).to(device)
class TrainableModel(BaseModel):
    r"""
    Base class for trainable graph models.

    Subclasses must implement the :meth:`loss` method. This class provides
    default implementations for :meth:`train_full` and :meth:`train_batch`
    that handle both single-loss and multi-component loss outputs.

    The :meth:`loss` method can return either:
        - A single :obj:`Tensor` for simple losses
        - A :class:`LossOutput` object for losses with multiple components

    Training messages go to :obj:`self.logger` when one is set (see
    :meth:`set_logger`) and are printed otherwise.
    """

    def __init__(self):
        super().__init__()
        self.logger = None  # Can be set externally

    def set_logger(self, logger):
        r"""Sets a custom logger for training output."""
        self.logger = logger

    def _log_info(self, message: str) -> None:
        r"""Sends :obj:`message` to :obj:`self.logger.info`, or prints it if no logger is set."""
        if self.logger:
            self.logger.info(message)
        else:
            print(message)

    @abstractmethod
    def loss(self, *args: Any, **kwargs: Any) -> Union[Tensor, LossOutput]:
        r"""
        Computes the training loss.

        Returns:
            Either a scalar loss :obj:`Tensor`, or a :class:`LossOutput` object
            containing the total loss and individual components.
        """
        pass

    def train_full(self, data: Data, optimizer: torch.optim.Optimizer,
                   epoch: int, verbose: bool = True, **loss_kwargs: Any) -> float:
        r"""
        Runs one epoch of full-batch training.

        Performs a single optimizer step on the loss computed from
        :obj:`data` unpacked as keyword arguments and merged with
        :obj:`loss_kwargs` (the latter win on key collisions).

        Args:
            data (Data): The input full graph data.
            optimizer (torch.optim.Optimizer): The optimizer.
            epoch (int): Current epoch number.
            verbose (bool, optional): If :obj:`True`, prints training progress.
                (default: :obj:`True`)
            **loss_kwargs: Additional keyword arguments passed to :meth:`loss`.

        Returns:
            Loss value of the epoch.
        """
        self.train()
        optimizer.zero_grad()

        # Merge data attributes with loss_kwargs
        loss_output = self.loss(**{**data, **loss_kwargs})

        # Handle both single Tensor and LossOutput returns
        has_components = isinstance(loss_output, LossOutput)
        loss = loss_output.total if has_components else loss_output

        loss.backward()
        optimizer.step()
        loss_value = float(loss.item())

        # Format the message only when it will actually be emitted
        if verbose:
            prefix = f"Epoch: {epoch:03d} "
            if has_components:
                log_str = loss_output.log_string(prefix)
            else:
                log_str = f"{prefix}Loss: {loss_value:.4f}"
            self._log_info(log_str)

        return loss_value

    def train_batch(self, loader: NeighborLoader, optimizer: torch.optim.Optimizer,
                    epoch: int, verbose: bool = True, **loss_kwargs: Any) -> float:
        r"""
        Runs one epoch of mini-batch training.

        Takes one optimizer step per batch. The returned loss and every
        logged component are averaged over the number of batches processed.
        An empty loader yields :obj:`0.0` instead of dividing by zero.

        Args:
            loader (NeighborLoader): The mini-batch loader.
            optimizer (torch.optim.Optimizer): The optimizer.
            epoch (int): Current epoch number.
            verbose (bool, optional): If :obj:`True`, prints training progress.
                (default: :obj:`True`)
            **loss_kwargs: Additional keyword arguments passed to :meth:`loss_batch`.

        Returns:
            Average loss value of the epoch.
        """
        self.train()
        device = next(self.parameters()).device

        pbar = None
        if verbose:
            from tqdm import tqdm
            # The bar tracks seed nodes: all nodes unless the loader was
            # restricted to a subset via ``input_nodes``.
            if loader.input_nodes is None:
                num_nodes = loader.data.num_nodes
            else:
                num_nodes = loader.input_nodes.size(0)
            pbar = tqdm(total=num_nodes)
            pbar.set_description(f'Epoch {epoch:03d}')

        # Accumulate loss components across batches
        total_loss = 0.0
        components_sum = {}
        num_batches = 0

        try:
            for batch in loader:
                batch = batch.to(device)
                optimizer.zero_grad()

                # Pass loss_kwargs to loss_batch
                loss_output = self.loss_batch(batch, **loss_kwargs)

                # Handle both single Tensor and LossOutput returns
                if isinstance(loss_output, LossOutput):
                    loss = loss_output.total
                    for name, value in loss_output.components.items():
                        components_sum[name] = components_sum.get(name, 0.0) + value
                else:
                    loss = loss_output

                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1

                if pbar is not None:
                    pbar.update(batch.batch_size)
        finally:
            # Release the progress bar even if a batch fails mid-epoch
            if pbar is not None:
                pbar.close()

        # Average over the batches actually seen; guard the empty-loader case
        denom = max(num_batches, 1)
        avg_loss = total_loss / denom

        # Print loss information
        if verbose:
            parts = [f"Loss: {avg_loss:.4f}"]
            parts.extend(
                f"{name.upper()}: {value / denom:.4f}"
                for name, value in components_sum.items()
            )
            self._log_info(f"Epoch: {epoch:03d} {', '.join(parts)}")

        return avg_loss

    def loss_batch(self, batch: Data, **kwargs: Any) -> Union[Tensor, LossOutput]:
        r"""
        Computes loss for a mini-batch.

        Default implementation simply calls :meth:`loss` with :obj:`batch`
        unpacked as keyword arguments and merged with :obj:`kwargs`.
        Subclasses can override this method to handle batch-specific logic
        (e.g., slicing to seed nodes).

        Args:
            batch (Data): A mini-batch from the loader.
            **kwargs: Additional keyword arguments.

        Returns:
            Loss output, same format as :meth:`loss`.
        """
        return self.loss(**{**batch, **kwargs})
class ClusteringModel(TrainableModel):
    r"""
    Base class for end-to-end clustering models.

    This class is designed for models that directly output cluster assignments
    (e.g., DMoN, MinCut) rather than embeddings. It provides a unified interface
    for clustering tasks by overriding the :meth:`infer_full` and :meth:`infer_batch`
    methods to return cluster assignments directly.

    Subclasses should implement:
        - :meth:`embed`: Returns node embeddings
        - :meth:`forward`: Returns hard cluster assignments
        - :meth:`loss`: Computes clustering loss

    :meth:`initialize_cluster_centers` additionally reads :obj:`self.n_clusters`
    and calls :obj:`self.cluster_head.reset_cluster_centers`, so subclasses
    that use it must provide both attributes.
    """

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Tensor:
        r"""
        Returns hard cluster assignments.

        Returns:
            Cluster assignments of shape :obj:`(num_nodes,)`.
        """
        pass

    @torch.no_grad()
    def infer_full(self, data: Data) -> Tensor:
        r"""
        Full-graph inference: returns cluster assignments for all nodes.

        Switches the module to eval mode and calls :meth:`forward` with
        :obj:`data` unpacked as keyword arguments, with gradients disabled.

        Args:
            data (Data): Input graph data.

        Returns:
            Cluster assignments of shape :obj:`(num_nodes,)`.
        """
        self.eval()
        return self.forward(**data)

    @torch.no_grad()
    def infer_batch(self, loader: NeighborLoader, verbose: bool = True) -> Tensor:
        r"""
        Mini-batch inference over a NeighborLoader.

        Keeps the first :obj:`batch.batch_size` predictions of every batch
        and concatenates them. Predictions stay on the device they were
        computed on.

        Args:
            loader (NeighborLoader): Mini-batch data loader.
            verbose (bool, optional): If :obj:`True`, displays a progress bar
                whose total is :obj:`loader.data.num_nodes`.
                (default: :obj:`True`)

        Returns:
            Cluster assignments for all nodes in the loader.
        """
        self.eval()

        # Peek at the first parameter only; a parameter-free model runs on CPU.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

        pbar = None
        if verbose:
            from tqdm import tqdm
            pbar = tqdm(total=loader.data.num_nodes)
            pbar.set_description("Inference stage")

        all_pred = []
        try:
            for batch in loader:
                batch = batch.to(device)
                pred = self.forward(**batch)
                all_pred.append(pred[:batch.batch_size])
                if pbar is not None:
                    pbar.update(batch.batch_size)
        finally:
            # Release the progress bar even if a batch fails mid-epoch
            if pbar is not None:
                pbar.close()

        return torch.cat(all_pred, dim=0)

    @torch.no_grad()
    def initialize_cluster_centers(self, data: Data, num_layers: int, train_idx: Tensor = None, batch_size: int = 4096,
                                   fan_out: int = -1, method: str = 'kmeans', verbose: bool = True):
        r"""
        Initialize cluster centers using K-Means.

        Embeds the selected nodes, L2-normalizes the embeddings, fits K-Means
        with :obj:`self.n_clusters` clusters and hands the resulting centers
        to :obj:`self.cluster_head.reset_cluster_centers`.

        Supports two modes:
            1. Small graphs: Use all nodes for initialization
            2. Large graphs: Use only training nodes or mini-batch inference

        A full-batch :meth:`embed` call is tried first. If it raises
        :class:`RuntimeError` (presumably out-of-memory on large graphs), the
        embeddings are recomputed batch by batch with a
        :class:`NeighborLoader`.

        Args:
            data (Data): Full graph data.
            num_layers (int): Number of encoder layers. Sets the number of
                sampling hops in the mini-batch fallback.
            train_idx (torch.Tensor, optional): Training node indices.
                If provided and shorter than :obj:`data.num_nodes`, only
                these nodes are used for initialization.
                (default: :obj:`None`)
            batch_size (int, optional): Batch size for mini-batch inference.
                (default: :obj:`4096`)
            fan_out (int, optional): Number of sampled neighbors.
                (default: :obj:`-1`)
            method (str, optional): Initialization method. Only
                :obj:`'kmeans'` is supported. (default: :obj:`'kmeans'`)
            verbose (bool, optional): If :obj:`True`, prints initializing progress.
                (default: :obj:`True`)

        Raises:
            TypeError: If :obj:`data` is not a :class:`Data` object.
            ValueError: If :obj:`method` is not :obj:`'kmeans'`.
        """
        self.eval()
        device = next(self.parameters()).device

        if not isinstance(data, Data):
            raise TypeError("data must be a torch_geometric.data.Data object")
        # Fail fast: reject an unsupported method before paying for embedding
        if method != 'kmeans':
            raise ValueError(f"Unknown initialization method: {method}")

        # Determine which nodes to use for initialization
        use_train_only = train_idx is not None and len(train_idx) < data.num_nodes
        input_nodes = train_idx if use_train_only else None
        num_nodes = len(train_idx) if use_train_only else data.num_nodes

        if verbose:
            log_str = f"Initializing cluster centers using {'subset' if use_train_only else 'all'} {num_nodes} nodes..."
            if self.logger:
                self.logger.info(log_str)
            else:
                print(log_str)

        # Try full-batch embedding first
        try:
            if use_train_only:
                from torch_geometric.utils import subgraph
                # Restrict edges to the training nodes and relabel them so
                # they index into the sliced feature matrix below.
                edge_index_subset, _ = subgraph(
                    train_idx,
                    data.edge_index,
                    relabel_nodes=True,
                    num_nodes=data.num_nodes
                )
                x = data.x[train_idx].to(device)
                edge_index = edge_index_subset.to(device)
            else:
                x = data.x.to(device)
                edge_index = data.edge_index.to(device)
            z = self.embed(x, edge_index)
        except RuntimeError as e:
            if verbose:
                log_str = (f"Full-batch embedding failed ({e}); "
                           f"using mini-batch inference (batch_size={batch_size})...")
                if self.logger:
                    self.logger.info(log_str)
                else:
                    print(log_str)

            # Mini-batch inference with NeighborLoader
            loader = NeighborLoader(
                data,
                input_nodes=input_nodes,
                num_neighbors=[fan_out] * num_layers,
                batch_size=batch_size,
                shuffle=False
            )
            all_z = []
            for batch in loader:
                batch = batch.to(device)
                z_batch = self.embed(batch.x, batch.edge_index)
                # Keep only the leading seed-node rows; stage them on CPU
                all_z.append(z_batch[:batch.batch_size].cpu())
            z = torch.cat(all_z, dim=0).to(device)

        # Normalize embeddings
        z = torch.nn.functional.normalize(z, p=2, dim=-1)

        # Run K-Means (imported lazily, as in the original code)
        from pyagc.clusters.kmeans_cluster_head import KMeansClusterHead
        kmeans = KMeansClusterHead(
            n_clusters=self.n_clusters,
            backend='torch',
            n_init=1,
            max_iter=300
        )
        kmeans.fit_predict(z)

        # Set cluster centers
        self.cluster_head.reset_cluster_centers(
            kmeans.cluster_centers.detach().to(device)
        )