pyagc.utils

The pyagc.utils package provides utility functions and classes for experiment management, including checkpoint management for single-stage and multi-stage training, configuration loading, logging, reproducibility, and common mathematical operations.

from pyagc.utils import (
    CheckpointManager,
    MultiStageCheckpointManager,
    set_seed,
    get_training_config,
    get_logger,
)

# Set random seeds for reproducibility:
set_seed(42)

# Load dataset-specific training config from YAML:
config = get_training_config("Cora", config_path="train.conf.yaml")

# Create a logger with file and console output:
logger = get_logger("experiment.log", log_level=1)

# Manage checkpoints during training (model and optimizer are created elsewhere):
ckpt_mgr = CheckpointManager(ckpt_dir="./checkpoints", model_name="dmon")
ckpt_mgr.save_checkpoint(model, optimizer, epoch=10, loss=0.35, is_best=True)

Checkpoint Management

PyAGC provides two checkpoint managers to handle model persistence during training. CheckpointManager supports standard single-stage training workflows, while MultiStageCheckpointManager extends it for multi-stage pipelines common in decoupled AGC methods (e.g., pre-training followed by fine-tuning).

Both managers automatically track the best model, support intra-epoch saving for mini-batch training on large graphs, and allow seamless training resumption.

from pyagc.utils import CheckpointManager, MultiStageCheckpointManager

# Single-stage checkpoint management:
ckpt = CheckpointManager("./ckpts", "dmon")
ckpt.save_checkpoint(model, optimizer, epoch=5, loss=0.42, is_best=True)
ckpt.load_checkpoint(model, optimizer, load_best=True, device="cuda")

# Multi-stage checkpoint management (e.g., pretrain + finetune):
ckpt = MultiStageCheckpointManager(
    "./ckpts", "daegc", stages=["pretrain", "finetune"]
)
ckpt.save_checkpoint(model, optimizer, epoch=100, loss=0.5, stage="pretrain", is_best=True)
ckpt.load_checkpoint(model, stage="pretrain", load_best=True, device="cuda")
ckpt.save_checkpoint(model, optimizer, epoch=50, loss=0.3, stage="finetune", is_best=True)
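
The "track the best model, then resume" pattern described above can be illustrated without PyTorch. The following is a minimal sketch only: TinyCheckpointer, its method names, and the `_last`/`_best` file naming are hypothetical and are not taken from the PyAGC source; plain dictionaries serialized as JSON stand in for model and optimizer state.

```python
import json
import tempfile
from pathlib import Path


class TinyCheckpointer:
    """Hypothetical stand-in that keeps a 'last' and a 'best' checkpoint file."""

    def __init__(self, ckpt_dir, model_name):
        self.dir = Path(ckpt_dir)
        self.dir.mkdir(parents=True, exist_ok=True)
        self.last_path = self.dir / f"{model_name}_last.json"  # assumed naming
        self.best_path = self.dir / f"{model_name}_best.json"  # assumed naming

    def save(self, state, epoch, loss, is_best=False):
        payload = {"state": state, "epoch": epoch, "loss": loss}
        self.last_path.write_text(json.dumps(payload))  # always refresh "last"
        if is_best:
            self.best_path.write_text(json.dumps(payload))  # refresh "best" on request

    def load(self, load_best=False):
        path = self.best_path if load_best else self.last_path
        return json.loads(path.read_text())


with tempfile.TemporaryDirectory() as tmp:
    ckpt = TinyCheckpointer(tmp, "dmon")
    ckpt.save({"w": [0.1, 0.2]}, epoch=5, loss=0.42, is_best=True)
    ckpt.save({"w": [0.3, 0.4]}, epoch=6, loss=0.50)  # worse loss, not marked best
    best = ckpt.load(load_best=True)
    last = ckpt.load()

resume_epoch = last["epoch"] + 1  # continue from the epoch after the last save
```

Keeping the two files separate is what lets evaluation use the best weights while resumption continues from the most recent state.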

CheckpointManager

Manages model checkpoints with support for resuming training.

MultiStageCheckpointManager

Checkpoint manager for multi-stage training.

Configuration & Logging

PyAGC adopts a configuration-driven experiment design. All hyperparameters are specified in YAML files with a hierarchical structure: a default section provides base configurations, and dataset-specific sections selectively override these defaults.

# train.conf.yaml
default:
  learning_rate: 0.001
  hidden_dim: 128
  model:
    num_layers: 2
    dropout: 0.5

Cora:
  learning_rate: 0.01
  model:
    num_layers: 3

CiteSeer:
  hidden_dim: 256

from pyagc.utils import get_training_config, get_logger

# Load merged configuration (default + dataset-specific overrides):
config = get_training_config("Cora", config_path="train.conf.yaml")
# >>> {'learning_rate': 0.01, 'hidden_dim': 128, 'model': {'num_layers': 3, 'dropout': 0.5}}

# Create a logger with both file and console output:
logger = get_logger("logs/experiment.log", log_level=1, name="pyagc")
logger.info("Training started")

get_training_config

Load training configuration from a YAML file with dataset-specific overrides.

get_logger

Create and configure a logger with both file and console handlers.

deep_update_dict

Recursively update a nested dictionary.

get_training_config(dataset: str, config_path: str = 'train.conf.yaml') -> dict

Load training configuration from a YAML file with dataset-specific overrides.

This function loads a hierarchical configuration file where a ‘default’ section provides base configurations and dataset-specific sections override these defaults. The merge is performed using deep dictionary updates to preserve nested structure.

The configuration file should follow this structure:

default:
  learning_rate: 0.001
  hidden_dim: 128
  model:
    num_layers: 2
    dropout: 0.5

Cora:
  learning_rate: 0.01
  model:
    num_layers: 3

CiteSeer:
  hidden_dim: 256

Parameters:
  • dataset (str) – Name of the dataset. Should match a top-level key in the configuration file (case-sensitive).

  • config_path (str, optional) – Path to the YAML configuration file. (default: 'train.conf.yaml')

Returns:

Merged configuration dictionary where dataset-specific values override default values. Nested dictionaries are recursively merged.

Return type:

dict

Raises:
  • FileNotFoundError – If the configuration file does not exist.

  • yaml.YAMLError – If the configuration file contains invalid YAML syntax.

Example

>>> # Given train.conf.yaml:
>>> # default:
>>> #   lr: 0.001
>>> #   hidden: 128
>>> # Cora:
>>> #   lr: 0.01
>>> config = get_training_config('Cora')
>>> print(config)
{'lr': 0.01, 'hidden': 128}

Note

  • If the dataset is not found in the config file, only default configuration is returned

  • Nested dictionaries are merged recursively via deep_update_dict()

  • This function does not validate configuration values
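
The lookup rules in the notes above can be sketched on an in-memory dictionary shaped like the parsed YAML file. This is an illustration under assumptions, not the library's implementation: merge_config and its inner overlay helper are hypothetical names, and the fallback for a wrongly cased key is inferred from the documented case-sensitive matching combined with the "dataset not found" rule.

```python
import copy


def merge_config(all_sections, dataset):
    """Hypothetical helper: overlay all_sections[dataset] on the 'default' section."""
    merged = copy.deepcopy(all_sections.get("default", {}))

    def overlay(base, overrides):
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                overlay(base[key], value)  # recurse so sibling keys survive
            else:
                base[key] = copy.deepcopy(value)

    overlay(merged, all_sections.get(dataset, {}))
    return merged


# Same content as the example configuration file, as a parsed dictionary.
sections = {
    "default": {"learning_rate": 0.001, "hidden_dim": 128,
                "model": {"num_layers": 2, "dropout": 0.5}},
    "Cora": {"learning_rate": 0.01, "model": {"num_layers": 3}},
    "CiteSeer": {"hidden_dim": 256},
}

cora = merge_config(sections, "Cora")
lowercase = merge_config(sections, "cora")  # no exact key match: defaults only
```

Because a missing section falls back silently to the defaults, a misspelled dataset name does not raise an error, so it may be worth logging the merged configuration at the start of a run.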

get_logger(filename: str, log_level: int = 1, name: Optional[str] = None, mode: str = 'a') -> Logger

Create and configure a logger with both file and console handlers.

Sets up a logger that writes to both a file and the console (stdout) with consistent formatting. The logger can be configured with different verbosity levels and can append to or overwrite existing log files.

Parameters:
  • filename (str) – Path to the log file. Parent directories will NOT be created automatically.

  • log_level (int, optional) –

    Logging verbosity level:

    • 0: DEBUG (most verbose)

    • 1: INFO (default)

    • 2: WARNING (least verbose)

    (default: 1)

  • name (str, optional) – Name for the logger. If None, uses the root logger. Use different names to maintain separate loggers. (default: None)

  • mode (str, optional) –

    File opening mode:

    • 'a': Append to existing file (default)

    • 'w': Overwrite existing file

    (default: 'a')

Returns:

Configured logger instance with both file and console handlers attached.

Return type:

logging.Logger

Note

  • The log format is: '%(asctime)s - %(filename)s - %(levelname)s - %(message)s'

  • Existing handlers are removed before adding new ones to avoid duplicates

  • Both file and console handlers use the same formatting

  • The logger is returned but also accessible via logging.getLogger(name)
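
The behaviour listed above (integer verbosity levels, handler replacement, one shared format, no automatic directory creation) can be reproduced with the standard logging module. The sketch below is an approximation under assumptions: make_logger and the LEVELS mapping are hypothetical names, and the code is not copied from PyAGC. It also shows the caller creating the parent directory before opening the log file.

```python
import logging
import os
import sys
import tempfile

# Assumed mapping from the documented integer levels to logging constants.
LEVELS = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}


def make_logger(filename, log_level=1, name=None, mode="a"):
    """Hypothetical sketch of a file-plus-console logger."""
    logger = logging.getLogger(name)
    logger.setLevel(LEVELS[log_level])
    for handler in list(logger.handlers):  # drop old handlers to avoid duplicates
        logger.removeHandler(handler)
        handler.close()
    formatter = logging.Formatter(
        "%(asctime)s - %(filename)s - %(levelname)s - %(message)s"
    )
    for handler in (logging.FileHandler(filename, mode=mode),
                    logging.StreamHandler(sys.stdout)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


log_dir = os.path.join(tempfile.mkdtemp(), "logs")
os.makedirs(log_dir, exist_ok=True)  # the parent directory must exist beforehand
log_path = os.path.join(log_dir, "experiment.log")

logger = make_logger(log_path, log_level=1, name="pyagc_sketch")
logger.debug("hidden at INFO level")
logger.info("Training started")
for handler in logger.handlers:
    handler.flush()

with open(log_path) as f:
    log_text = f.read()
```

Calling make_logger a second time with the same name would leave exactly two handlers attached, which is the point of removing the existing ones first.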

deep_update_dict(base: dict, overrides: dict) -> dict

Recursively update a nested dictionary.

Performs a deep merge of two dictionaries, where values from overrides are recursively merged into base. For nested dictionaries, this function recursively updates the nested structure. For non-dict values, the override value replaces the base value.

Parameters:
  • base (dict) – The base dictionary to be updated. This dictionary is modified in-place.

  • overrides (dict) – Dictionary containing values to merge into base. Nested dictionaries are recursively merged.

Returns:

The updated base dictionary (same object as input, modified in-place).

Return type:

dict

Example

>>> base = {'a': 1, 'b': {'c': 2, 'd': 3}}
>>> overrides = {'b': {'c': 20, 'e': 4}, 'f': 5}
>>> result = deep_update_dict(base, overrides)
>>> print(result)
{'a': 1, 'b': {'c': 20, 'd': 3, 'e': 4}, 'f': 5}

Note

  • This function modifies base in-place

  • For nested dictionaries, only dict values are recursively merged

  • Non-dict values in overrides always replace values in base

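  • To preserve the original base, pass a deep copy (copy.deepcopy(base)); a shallow base.copy() can still share nested dictionaries with the original, and those are updated in-place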
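
Because the merge recurses into nested dictionaries in place, how base is copied matters. The sketch below uses a hypothetical stand-in, deep_update_sketch, written to match the behaviour documented here (it is not the library's source), and compares a shallow copy with a deep copy.

```python
import copy


def deep_update_sketch(base, overrides):
    """Hypothetical stand-in with the documented in-place merge behaviour."""
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update_sketch(base[key], value)  # mutates the nested dict itself
        else:
            base[key] = value
    return base


overrides = {"b": {"c": 20}}

# Shallow copy: the top level is new, but original["b"] is the same object.
original = {"a": 1, "b": {"c": 2, "d": 3}}
deep_update_sketch(original.copy(), overrides)
shallow_leaked = original["b"]["c"]

# Deep copy: nested dictionaries are duplicated, so the original is untouched.
original = {"a": 1, "b": {"c": 2, "d": 3}}
merged = deep_update_sketch(copy.deepcopy(original), overrides)
deep_preserved = original["b"]["c"]
```

With this stand-in, shallow_leaked ends up as the overridden value while deep_preserved keeps the original one.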

Reproducibility

set_seed

Set random seeds for reproducibility across multiple libraries.

set_seed(seed: int) -> None

Set random seeds for reproducibility across multiple libraries.

Return type:

None
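
The documentation does not list which libraries set_seed covers. As a rough illustration of what "across multiple libraries" usually involves, the sketch below seeds Python's random module and, only if they happen to be installed, NumPy and PyTorch. The function name seed_everything and the choice of libraries are assumptions, not the PyAGC implementation.

```python
import random


def seed_everything(seed):
    """Hypothetical sketch: seed whichever common RNG libraries are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]  # identical to the first draw
```

Seeding makes the random streams repeatable; it does not by itself guarantee bitwise-identical results for every GPU operation.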

Mathematical Utilities

Common mathematical operations used across the library, including distance computation and matrix manipulation.

import torch
from pyagc.utils import pairwise_squared_distance, off_diagonal

# Compute pairwise squared Euclidean distances (e.g., for KMeans):
x = torch.randn(1000, 128)   # node embeddings
centers = torch.randn(7, 128)  # cluster centers
dists = pairwise_squared_distance(x, centers)  # (1000, 7)

# Extract off-diagonal elements (e.g., for regularization losses):
corr = torch.randn(128, 128)
off_diag = off_diagonal(corr)  # (128 * 127,)

pairwise_squared_distance

Compute pairwise squared Euclidean distances between two sets of vectors.

off_diagonal

Extract off-diagonal elements from a square matrix.

filter_kwargs

Filter keyword arguments based on function signature.

pairwise_squared_distance(x: Tensor, y: Tensor) -> Tensor

Compute pairwise squared Euclidean distances between two sets of vectors.

Efficiently computes the squared \(L_2\) distance between all pairs of vectors from two sets using the identity:

\[\| \mathbf{x}_i - \mathbf{y}_j \|_2^2 = \| \mathbf{x}_i \|_2^2 - 2 \mathbf{x}_i^\top \mathbf{y}_j + \| \mathbf{y}_j \|_2^2\]

where \(\mathbf{x}_i\) and \(\mathbf{y}_j\) are the \(i\)-th and \(j\)-th vectors in sets \(x\) and \(y\) respectively.

This vectorized implementation is more efficient than naive nested loops and is commonly used in clustering algorithms (e.g., K-Means) and nearest neighbor computations.
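
The identity can be checked by hand on a small input. The sketch below uses plain Python lists instead of tensors and compares a direct pair-by-pair computation with the expanded form; the helper names sq_dist_direct and sq_dist_expanded are illustrative only and are not part of the library.

```python
def sq_dist_direct(x, y):
    """Sum of squared coordinate differences, computed pair by pair."""
    return [[sum((a - b) ** 2 for a, b in zip(xi, yj)) for yj in y] for xi in x]


def sq_dist_expanded(x, y):
    """Same quantity via ||x||^2 - 2 x.y + ||y||^2."""
    x_norms = [sum(a * a for a in xi) for xi in x]
    y_norms = [sum(b * b for b in yj) for yj in y]
    return [
        [x_norms[i] - 2 * sum(a * b for a, b in zip(xi, yj)) + y_norms[j]
         for j, yj in enumerate(y)]
        for i, xi in enumerate(x)
    ]


x = [[1.0, 2.0], [3.0, 4.0]]  # 2 samples
y = [[0.0, 0.0], [1.0, 1.0]]  # 2 centers
direct = sq_dist_direct(x, y)      # [[5.0, 1.0], [25.0, 13.0]]
expanded = sq_dist_expanded(x, y)  # same values
```

In the expanded form the norms are computed once per vector and the cross term is a single matrix product, which is what makes a tensor implementation efficient. With floating-point tensors the expanded form can produce tiny negative values for near-identical vectors, so clamping at zero is a common precaution.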

Parameters:
  • x (Tensor) – First set of vectors of shape (B, D), where \(B\) is the number of samples and \(D\) is the feature dimension.

  • y (Tensor) – Second set of vectors of shape (K, D), where \(K\) is the number of reference points (e.g., cluster centers).

Returns:

Squared distance matrix of shape (B, K), where element [i, j] contains \(\| \mathbf{x}_i - \mathbf{y}_j \|_2^2\).

Return type:

Tensor

Example

>>> x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 2 samples
>>> y = torch.tensor([[0.0, 0.0], [1.0, 1.0]])  # 2 centers
>>> distances = pairwise_squared_distance(x, y)
>>> print(distances)
tensor([[ 5.,  1.],
        [25., 13.]])

off_diagonal(x: Tensor) -> Tensor

Extract off-diagonal elements from a square matrix.

Returns a flattened view of all off-diagonal elements of a square matrix. This is useful for computing losses or metrics that exclude the diagonal, such as off-diagonal regularization in self-supervised learning.

Parameters:

x (Tensor) – A square matrix of shape (n, n).

Returns:

Flattened tensor of shape (n * (n-1),) containing all off-diagonal elements in row-major order.

Return type:

Tensor

Raises:

AssertionError – If the input is not a square matrix.

Example

>>> x = torch.tensor([[1, 2, 3],
...                   [4, 5, 6],
...                   [7, 8, 9]])
>>> off_diagonal(x)
tensor([2, 3, 4, 6, 7, 8])

Note

This function is memory-efficient as it returns a view rather than a copy of the data when possible.
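
The row-major ordering and the n * (n-1) output length can be illustrated on nested lists. The sketch below shows two equivalent routes: an explicit index filter, and a reshaping trick often used for this operation (flatten, drop the last element, regroup into rows of length n + 1, drop the first column). Both helper names are hypothetical, and the source does not state which approach the library uses.

```python
def off_diagonal_direct(matrix):
    """Keep every element whose row index differs from its column index."""
    n = len(matrix)
    assert all(len(row) == n for row in matrix), "expected a square matrix"
    return [matrix[i][j] for i in range(n) for j in range(n) if i != j]


def off_diagonal_reshape(matrix):
    """Same result via flatten -> drop last -> rows of n + 1 -> drop first column."""
    n = len(matrix)
    assert all(len(row) == n for row in matrix), "expected a square matrix"
    flat = [value for row in matrix for value in row][:-1]  # drop last diagonal entry
    rows = [flat[k:k + n + 1] for k in range(0, len(flat), n + 1)]
    return [value for row in rows for value in row[1:]]  # each row starts on a diagonal


m = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
direct = off_diagonal_direct(m)      # [2, 3, 4, 6, 7, 8]
reshaped = off_diagonal_reshape(m)   # same list
```

The trick works because, after dropping the final element, the remaining n * n - 1 = (n - 1)(n + 1) values regroup so that every diagonal entry sits in the first column.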

filter_kwargs(func: Callable, kwargs: dict) -> dict

Filter keyword arguments based on function signature.

This utility function inspects a function’s signature and returns only the keyword arguments that are valid parameters for that function. This is useful for passing flexible kwargs to functions without worrying about unsupported parameters.

Parameters:
  • func (callable) – The function whose signature will be inspected.

  • kwargs (dict) – Dictionary of keyword arguments to filter.

Returns:

A filtered dictionary containing only the keys that match the function’s parameter names.

Return type:

dict

Example

>>> def my_func(a, b, c=3):
...     return a + b + c
>>> kwargs = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
>>> filtered = filter_kwargs(my_func, kwargs)
>>> print(filtered)
{'a': 1, 'b': 2, 'c': 3}
>>> my_func(**filtered)
6
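
A typical use is passing one shared configuration dictionary to several functions that each accept only some of its keys. The sketch below shows one way such filtering can be written with inspect.signature; filter_kwargs_sketch and the train function are hypothetical illustrations, not the library's code.

```python
import inspect


def filter_kwargs_sketch(func, kwargs):
    """Hypothetical stand-in: keep only keys named in func's signature."""
    accepted = inspect.signature(func).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}


def train(lr, epochs, dropout=0.5):
    return f"lr={lr}, epochs={epochs}, dropout={dropout}"


# One shared dictionary; hidden_dim is meant for a different function.
shared = {"lr": 0.01, "epochs": 200, "hidden_dim": 128}
filtered = filter_kwargs_sketch(train, shared)
summary = train(**filtered)
```

Keys that are dropped are dropped silently, so a misspelled option name would go unnoticed; parameters that the dictionary does not supply simply keep their defaults.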