pyagc.utils
The pyagc.utils package provides utility functions and classes for experiment management,
including checkpoint management for single-stage and multi-stage training, configuration loading,
logging, reproducibility, and common mathematical operations.
from pyagc.utils import (
    CheckpointManager,
    MultiStageCheckpointManager,
    set_seed,
    get_training_config,
    get_logger,
)
# Set random seeds for reproducibility:
set_seed(42)
# Load dataset-specific training config from YAML:
config = get_training_config("Cora", config_path="train.conf.yaml")
# Create a logger with file and console output:
logger = get_logger("experiment.log", log_level=1)
# Manage checkpoints during training:
ckpt_mgr = CheckpointManager(ckpt_dir="./checkpoints", model_name="dmon")
ckpt_mgr.save_checkpoint(model, optimizer, epoch=10, loss=0.35, is_best=True)
Checkpoint Management
PyAGC provides two checkpoint managers to handle model persistence during training.
CheckpointManager supports standard single-stage training workflows,
while MultiStageCheckpointManager extends it for multi-stage pipelines
common in decoupled AGC methods (e.g., pre-training followed by fine-tuning).
Both managers automatically track the best model, support intra-epoch saving for mini-batch training on large graphs, and allow seamless training resumption.
from pyagc.utils import CheckpointManager, MultiStageCheckpointManager
# Single-stage checkpoint management:
ckpt = CheckpointManager("./ckpts", "dmon")
ckpt.save_checkpoint(model, optimizer, epoch=5, loss=0.42, is_best=True)
ckpt.load_checkpoint(model, optimizer, load_best=True, device="cuda")
# Multi-stage checkpoint management (e.g., pretrain + finetune):
ckpt = MultiStageCheckpointManager(
    "./ckpts", "daegc", stages=["pretrain", "finetune"]
)
ckpt.save_checkpoint(model, optimizer, epoch=100, loss=0.5, stage="pretrain", is_best=True)
ckpt.load_checkpoint(model, stage="pretrain", load_best=True, device="cuda")
ckpt.save_checkpoint(model, optimizer, epoch=50, loss=0.3, stage="finetune", is_best=True)
CheckpointManager – Manages model checkpoints with support for resuming training.
MultiStageCheckpointManager – Checkpoint manager for multi-stage training.
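The best-model tracking and resumption described above can be illustrated with a small, library-independent sketch. `TinyCheckpointStore` and its file names are hypothetical, and it pickles plain dictionaries; this is not PyAGC's implementation, which works with model and optimizer objects.

```python
import os
import pickle
import tempfile


class TinyCheckpointStore:
    """Hypothetical stand-in that keeps a 'latest' and a 'best' checkpoint file."""

    def __init__(self, ckpt_dir, model_name):
        os.makedirs(ckpt_dir, exist_ok=True)
        self.latest_path = os.path.join(ckpt_dir, f"{model_name}_latest.pkl")
        self.best_path = os.path.join(ckpt_dir, f"{model_name}_best.pkl")

    def save(self, state, epoch, loss, is_best=False):
        payload = {"state": state, "epoch": epoch, "loss": loss}
        # Always overwrite the latest checkpoint; overwrite the best one on request.
        paths = [self.latest_path] + ([self.best_path] if is_best else [])
        for path in paths:
            with open(path, "wb") as f:
                pickle.dump(payload, f)

    def load(self, load_best=False):
        path = self.best_path if load_best else self.latest_path
        with open(path, "rb") as f:
            return pickle.load(f)


ckpt_dir = tempfile.mkdtemp()
store = TinyCheckpointStore(ckpt_dir, "demo")
store.save({"w": 1.0}, epoch=1, loss=0.9, is_best=True)
store.save({"w": 2.0}, epoch=2, loss=1.1)  # worse loss: only 'latest' is updated

best = store.load(load_best=True)   # checkpoint from epoch 1
latest = store.load()               # checkpoint from epoch 2
```

Keeping the two files separate means training can resume from the most recent state while evaluation can still use the best one.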
Configuration & Logging
PyAGC adopts a configuration-driven experiment design. All hyperparameters are
specified in YAML files with a hierarchical structure: a default section provides
base configurations, and dataset-specific sections selectively override these defaults.
# train.conf.yaml
default:
  learning_rate: 0.001
  hidden_dim: 128
  model:
    num_layers: 2
    dropout: 0.5
Cora:
  learning_rate: 0.01
  model:
    num_layers: 3
CiteSeer:
  hidden_dim: 256
from pyagc.utils import get_training_config, get_logger
# Load merged configuration (default + dataset-specific overrides):
config = get_training_config("Cora", config_path="train.conf.yaml")
# >>> {'learning_rate': 0.01, 'hidden_dim': 128, 'model': {'num_layers': 3, 'dropout': 0.5}}
# Create a logger with both file and console output:
logger = get_logger("logs/experiment.log", log_level=1, name="pyagc")
logger.info("Training started")
get_training_config – Load training configuration from a YAML file with dataset-specific overrides.
get_logger – Create and configure a logger with both file and console handlers.
deep_update_dict – Recursively update a nested dictionary.
- get_training_config(dataset: str, config_path: str = 'train.conf.yaml') -> dict
Load training configuration from a YAML file with dataset-specific overrides.
This function loads a hierarchical configuration file where a ‘default’ section provides base configurations and dataset-specific sections override these defaults. The merge is performed using deep dictionary updates to preserve nested structure.
The configuration file should follow this structure:
default:
  learning_rate: 0.001
  hidden_dim: 128
  model:
    num_layers: 2
    dropout: 0.5
Cora:
  learning_rate: 0.01
  model:
    num_layers: 3
CiteSeer:
  hidden_dim: 256
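- Parameters:
dataset (str) – Name of the dataset whose section overrides the defaults (e.g., 'Cora').
config_path (str, optional) – Path to the YAML configuration file. (default: 'train.conf.yaml')
- Returns:
Merged configuration dictionary where dataset-specific values override default values. Nested dictionaries are recursively merged.
- Return type:
dict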
- Raises:
FileNotFoundError – If the configuration file does not exist.
yaml.YAMLError – If the configuration file contains invalid YAML syntax.
Example
>>> # Given train.conf.yaml:
>>> # default:
>>> #   lr: 0.001
>>> #   hidden: 128
>>> # Cora:
>>> #   lr: 0.01
>>> config = get_training_config('Cora')
>>> print(config)
{'lr': 0.01, 'hidden': 128}
Note
- If the dataset is not found in the config file, only the default configuration is returned.
- Nested dictionaries are merged recursively via deep_update_dict().
- This function does not validate configuration values.
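The override behaviour described above can be sketched on in-memory dictionaries, without reading a YAML file. The helper names `merge` and `resolve_config` are hypothetical and only mirror the documented semantics; they are not the library's code.

```python
import copy


def merge(base, overrides):
    """Recursively merge `overrides` into `base` in place and return `base`."""
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            merge(base[key], value)
        else:
            base[key] = value
    return base


def resolve_config(all_sections, dataset):
    """Start from a copy of 'default' and apply the dataset section if present."""
    config = copy.deepcopy(all_sections.get("default", {}))
    return merge(config, all_sections.get(dataset, {}))


sections = {
    "default": {"learning_rate": 0.001, "hidden_dim": 128,
                "model": {"num_layers": 2, "dropout": 0.5}},
    "Cora": {"learning_rate": 0.01, "model": {"num_layers": 3}},
}

cora = resolve_config(sections, "Cora")
unknown = resolve_config(sections, "PubMed")  # no such section: defaults only
```

Here `cora` keeps `hidden_dim` and `dropout` from the defaults while taking `learning_rate` and `num_layers` from the dataset section, and `unknown` equals the default section.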
- get_logger(filename: str, log_level: int = 1, name: Optional[str] = None, mode: str = 'a') -> Logger
Create and configure a logger with both file and console handlers.
Sets up a logger that writes to both a file and the console (stdout) with consistent formatting. The logger can be configured with different verbosity levels and can append to or overwrite existing log files.
- Parameters:
filename (str) – Path to the log file. Parent directories will NOT be created automatically.
log_level (int, optional) – Logging verbosity level (default: 1):
- 0: DEBUG (most verbose)
- 1: INFO (default)
- 2: WARNING (least verbose)
name (str, optional) – Name for the logger. If None, uses the root logger. Use different names to maintain separate loggers. (default: None)
mode (str, optional) – File opening mode (default: 'a'):
- 'a': Append to existing file
- 'w': Overwrite existing file
- Returns:
Configured logger instance with both file and console handlers attached.
- Return type:
Logger
Note
- The log format is '%(asctime)s - %(filename)s - %(levelname)s - %(message)s'.
- Existing handlers are removed before adding new ones to avoid duplicates.
- Both file and console handlers use the same formatting.
- The logger is returned but is also accessible via logging.getLogger(name).
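The handler setup described above can be approximated with the standard `logging` module. `make_logger` is a hypothetical helper written for illustration, not the library's `get_logger`; it takes a `logging` level constant directly rather than the 0/1/2 codes.

```python
import logging
import os
import sys
import tempfile


def make_logger(filename, level=logging.INFO, name=None, mode="a"):
    """Hypothetical helper: attach one file and one console handler to a logger."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Drop handlers from earlier calls so messages are not emitted twice.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()
    formatter = logging.Formatter(
        "%(asctime)s - %(filename)s - %(levelname)s - %(message)s"
    )
    for handler in (logging.FileHandler(filename, mode=mode),
                    logging.StreamHandler(sys.stdout)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


log_path = os.path.join(tempfile.mkdtemp(), "logs", "experiment.log")
# Create the parent directory first, since the logger is documented not to do so.
os.makedirs(os.path.dirname(log_path), exist_ok=True)

logger = make_logger(log_path, name="pyagc_demo")
logger = make_logger(log_path, name="pyagc_demo")  # repeated call keeps two handlers
logger.info("Training started")
```

Calling `os.makedirs` before creating the logger avoids a `FileNotFoundError` when the log path points into a directory that does not exist yet.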
- deep_update_dict(base: dict, overrides: dict) -> dict
Recursively update a nested dictionary.
Performs a deep merge of two dictionaries, where values from overrides are recursively merged into base. For nested dictionaries, this function recursively updates the nested structure. For non-dict values, the override value replaces the base value.
- Parameters:
base (dict) – Dictionary to update; modified in-place.
overrides (dict) – Dictionary whose values are merged into base.
- Returns:
The updated base dictionary (same object as input, modified in-place).
- Return type:
dict
Example
>>> base = {'a': 1, 'b': {'c': 2, 'd': 3}}
>>> overrides = {'b': {'c': 20, 'e': 4}, 'f': 5}
>>> result = deep_update_dict(base, overrides)
>>> print(result)
{'a': 1, 'b': {'c': 20, 'd': 3, 'e': 4}, 'f': 5}
Note
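- This function modifies base in-place.
- For nested dictionaries, only dict values are recursively merged.
- Non-dict values in overrides always replace values in base.
- To preserve the original base, pass copy.deepcopy(base). A shallow base.copy() still shares nested dictionaries with the original, so nested updates would modify it.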
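Because the merge is in-place and recursive, how `base` is copied beforehand matters. The following sketch uses only the standard library and shows that a shallow copy still shares nested dictionaries with the original, whereas `copy.deepcopy` does not.

```python
import copy

base = {"a": 1, "b": {"c": 2}}
shallow = base.copy()
deep = copy.deepcopy(base)

# A nested in-place update through the shallow copy leaks into `base`...
shallow["b"]["c"] = 20
# ...while the deep copy owns its nested dictionaries.
deep["b"]["c"] = 99
```

After these two assignments `base["b"]["c"]` is 20, not 2, which is the same effect a recursive in-place merge into a shallow copy would have.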
Reproducibility
set_seed – Set random seeds for reproducibility across multiple libraries.
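What seeding "across multiple libraries" typically involves can be sketched as follows. `seed_everything` is a hypothetical helper, not PyAGC's `set_seed`; it seeds Python's `random` module and, only if they happen to be installed, NumPy and PyTorch.

```python
import os
import random


def seed_everything(seed):
    """Hypothetical helper: seed every RNG that is available in this environment."""
    random.seed(seed)
    # Only affects hash randomization in Python subprocesses started afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]  # same three numbers as `first`
```

Seeding alone does not guarantee bit-identical GPU results; whether further determinism settings are applied by `set_seed` is not stated here.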
Mathematical Utilities
Common mathematical operations used across the library, including distance computation and matrix manipulation.
import torch
from pyagc.utils import pairwise_squared_distance, off_diagonal
# Compute pairwise squared Euclidean distances (e.g., for KMeans):
x = torch.randn(1000, 128) # node embeddings
centers = torch.randn(7, 128) # cluster centers
dists = pairwise_squared_distance(x, centers) # (1000, 7)
# Extract off-diagonal elements (e.g., for regularization losses):
corr = torch.randn(128, 128)
off_diag = off_diagonal(corr) # (128 * 127,)
pairwise_squared_distance – Compute pairwise squared Euclidean distances between two sets of vectors.
off_diagonal – Extract off-diagonal elements from a square matrix.
filter_kwargs – Filter keyword arguments based on function signature.
- pairwise_squared_distance(x: Tensor, y: Tensor) -> Tensor
Compute pairwise squared Euclidean distances between two sets of vectors.
Efficiently computes the squared \(L_2\) distance between all pairs of vectors from two sets using the identity:
\[\| \mathbf{x}_i - \mathbf{y}_j \|_2^2 = \| \mathbf{x}_i \|_2^2 - 2 \mathbf{x}_i^\top \mathbf{y}_j + \| \mathbf{y}_j \|_2^2\]
where \(\mathbf{x}_i\) and \(\mathbf{y}_j\) are the \(i\)-th and \(j\)-th vectors in sets \(x\) and \(y\) respectively.
This vectorized implementation is more efficient than naive nested loops and is commonly used in clustering algorithms (e.g., K-Means) and nearest neighbor computations.
- Parameters:
x (Tensor) – First set of vectors of shape (B, D), where \(B\) is the number of samples and \(D\) is the feature dimension.
y (Tensor) – Second set of vectors of shape (K, D), where \(K\) is the number of reference points (e.g., cluster centers).
- Returns:
Squared distance matrix of shape (B, K), where element [i, j] contains \(\| \mathbf{x}_i - \mathbf{y}_j \|_2^2\).
- Return type:
Tensor
Example
>>> x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 2 samples
>>> y = torch.tensor([[0.0, 0.0], [1.0, 1.0]])  # 2 centers
>>> distances = pairwise_squared_distance(x, y)
>>> print(distances)
tensor([[ 5.,  1.],
        [25., 13.]])
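The identity behind this function can be checked by hand on the two samples and two centers used in the example. The sketch below is a pure-Python illustration on nested lists; `pairwise_sq_dist` is a hypothetical name, and a tensor implementation would presumably replace the inner sums with a matrix product.

```python
def pairwise_sq_dist(xs, ys):
    """Pure-Python sketch of ||x||^2 - 2 x.y + ||y||^2 for lists of vectors."""
    x_norms = [sum(v * v for v in x) for x in xs]
    y_norms = [sum(v * v for v in y) for y in ys]
    return [
        [x_norms[i] - 2.0 * sum(a * b for a, b in zip(x, y)) + y_norms[j]
         for j, y in enumerate(ys)]
        for i, x in enumerate(xs)
    ]


xs = [[1.0, 2.0], [3.0, 4.0]]
ys = [[0.0, 0.0], [1.0, 1.0]]
dists = pairwise_sq_dist(xs, ys)  # [[5.0, 1.0], [25.0, 13.0]]
```

For instance, the entry for x = (1, 2) and y = (1, 1) is 5 - 2 * 3 + 2 = 1, which matches the direct computation (1 - 1)^2 + (2 - 1)^2 = 1.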
- off_diagonal(x: Tensor) -> Tensor
Extract off-diagonal elements from a square matrix.
Returns a flattened view of all off-diagonal elements of a square matrix. This is useful for computing losses or metrics that exclude the diagonal, such as off-diagonal regularization in self-supervised learning.
- Parameters:
x (Tensor) – A square matrix of shape (n, n).
- Returns:
Flattened tensor of shape (n * (n-1),) containing all off-diagonal elements in row-major order.
- Return type:
Tensor
- Raises:
AssertionError – If the input is not a square matrix.
Example
>>> x = torch.tensor([[1, 2, 3],
...                   [4, 5, 6],
...                   [7, 8, 9]])
>>> off_diagonal(x)
tensor([2, 3, 4, 6, 7, 8])
Note
This function is memory-efficient as it returns a view rather than a copy of the data when possible.
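One common way to extract off-diagonal elements without an explicit mask is a reshape trick: flatten the matrix, drop the last element, regroup into rows of length n + 1, and discard the first column. Whether PyAGC uses this approach is an assumption; the sketch below, with the hypothetical name `off_diagonal_flat`, reproduces the idea on nested lists.

```python
def off_diagonal_flat(matrix):
    """Pure-Python sketch of the flatten-and-regroup trick for a square matrix."""
    n = len(matrix)
    assert all(len(row) == n for row in matrix), "input must be square"
    flat = [v for row in matrix for v in row]
    # Dropping the last element and regrouping into rows of length n + 1
    # puts every remaining diagonal element in the first column.
    trimmed = flat[:-1]
    rows = [trimmed[i * (n + 1):(i + 1) * (n + 1)] for i in range(n - 1)]
    return [v for row in rows for v in row[1:]]


result = off_diagonal_flat([[1, 2, 3],
                            [4, 5, 6],
                            [7, 8, 9]])  # [2, 3, 4, 6, 7, 8]
```

On a tensor, the analogous expression would be `x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()`, which yields the n * (n - 1) elements in the same row-major order.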
- filter_kwargs(func: Callable, kwargs: dict) -> dict
Filter keyword arguments based on function signature.
This utility function inspects a function’s signature and returns only the keyword arguments that are valid parameters for that function. This is useful for passing flexible kwargs to functions without worrying about unsupported parameters.
- Parameters:
func (callable) – The function whose signature will be inspected.
kwargs (dict) – Dictionary of keyword arguments to filter.
- Returns:
A filtered dictionary containing only the keys that match the function’s parameter names.
- Return type:
dict
Example
>>> def my_func(a, b, c=3):
...     return a + b + c
>>> kwargs = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
>>> filtered = filter_kwargs(my_func, kwargs)
>>> print(filtered)
{'a': 1, 'b': 2, 'c': 3}
>>> my_func(**filtered)
6
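Signature-based filtering of this kind can be written with the standard `inspect` module. The helper `keep_supported_kwargs` below is a hypothetical sketch, not the library's implementation, and it does not treat functions that themselves accept `**kwargs` specially.

```python
import inspect


def keep_supported_kwargs(func, kwargs):
    """Hypothetical sketch: keep only keys that name a parameter of `func`."""
    accepted = inspect.signature(func).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}


def my_func(a, b, c=3):
    return a + b + c


filtered = keep_supported_kwargs(my_func, {"a": 1, "b": 2, "c": 3, "d": 4})
total = my_func(**filtered)  # 6
```

The unsupported key `d` is dropped, so the call no longer raises a `TypeError` for an unexpected keyword argument.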