# Source code for pyagc.utils.checkpoint

"""Checkpoint management utilities."""

import os
import torch
from typing import Optional, Dict, Any


[docs]class CheckpointManager: """Manages model checkpoints with support for resuming training."""
[docs] def __init__(self, ckpt_dir: str, model_name: str, logger=None): """ Args: ckpt_dir (str): Directory to save checkpoints model_name (str): Base name for checkpoint files logger: Logger instance for logging """ self.ckpt_dir = ckpt_dir self.model_name = model_name self.logger = logger os.makedirs(ckpt_dir, exist_ok=True) # Define checkpoint paths self.best_path = os.path.join(ckpt_dir, f"{model_name}_best.pt") self.last_path = os.path.join(ckpt_dir, f"{model_name}_last.pt") # Track best performance self.best_loss = float('inf')
[docs] def save_checkpoint(self, model, optimizer, epoch: int, loss: float, is_best: bool = False, batch_idx: Optional[int] = None, additional_info: Optional[Dict[str, Any]] = None): """Save a checkpoint with full training state. Args: model: The model to save optimizer: The optimizer state epoch (int): Current epoch number loss (float): Current loss value is_best (bool): Whether this is the best model so far batch_idx (int, optional): Current batch index within epoch (for intra-epoch saves) additional_info (dict): Additional information to save """ checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, 'best_loss': self.best_loss, } # Add batch information if provided (for intra-epoch checkpoints) if batch_idx is not None: checkpoint['batch_idx'] = batch_idx if additional_info is not None: checkpoint.update(additional_info) # Always save last checkpoint torch.save(checkpoint, self.last_path) # Save best checkpoint if this is the best model if is_best: torch.save(checkpoint, self.best_path) self.best_loss = loss if self.logger: batch_str = f" (batch {batch_idx})" if batch_idx is not None else "" self.logger.info(f"✓ Best model saved at epoch {epoch}{batch_str} with loss {loss:.4f}") if self.logger and batch_idx is not None: # Log for intra-epoch saves if batch_idx % 100 == 0: # Log every 100 batches to avoid spam self.logger.info(f"✓ Checkpoint saved at epoch {epoch}, batch {batch_idx}")
[docs] def load_checkpoint(self, model, optimizer=None, load_best: bool = True, device: Optional[str|torch.device] = 'cpu'): """Load a checkpoint and restore training state. Args: model: The model to load weights into optimizer: The optimizer to restore state (optional) load_best (bool): If True, load best checkpoint; otherwise load last device: Device to map the checkpoint to Returns: Dictionary with checkpoint information (epoch, loss, batch_idx, etc.) Returns None if checkpoint doesn't exist """ ckpt_path = self.best_path if load_best else self.last_path if not os.path.exists(ckpt_path): if self.logger: self.logger.info(f"No checkpoint found at {ckpt_path}") return None if self.logger: self.logger.info(f"Loading checkpoint from {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) # Restore model state model.load_state_dict(checkpoint['model_state_dict']) # Restore optimizer state if provided if optimizer is not None and 'optimizer_state_dict' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # Update best loss tracking if 'best_loss' in checkpoint: self.best_loss = checkpoint['best_loss'] if self.logger: epoch = checkpoint.get('epoch', 'unknown') loss = checkpoint.get('loss', 'unknown') batch_idx = checkpoint.get('batch_idx', None) batch_str = f", batch={batch_idx}" if batch_idx is not None else "" self.logger.info(f"✓ Checkpoint loaded: epoch={epoch}{batch_str}, loss={loss}") return checkpoint
[docs] def has_checkpoint(self, load_best: bool = True) -> bool: """Check if a checkpoint exists. Args: load_best (bool): If True, check for best checkpoint; otherwise check last Returns: bool: True if checkpoint exists """ ckpt_path = self.best_path if load_best else self.last_path return os.path.exists(ckpt_path)
class MultiStageCheckpointManager(CheckpointManager):
    """Checkpoint manager for multi-stage training.

    Each stage gets its own pair of files inside ``ckpt_dir``:
    ``{model_name}_{stage}_best.pt`` and ``{model_name}_{stage}_last.pt``,
    plus its own best-loss tracker in ``self.best_losses``.
    """

    def __init__(self, ckpt_dir: str, model_name: str,
                 stages: Optional[list] = None, logger=None):
        """Set up per-stage checkpoint paths and best-loss trackers.

        Args:
            ckpt_dir (str): Directory to save checkpoints.
            model_name (str): Base name for checkpoint files.
            stages: List of stage names (e.g., ['pretrain', 'finetune']).
                Defaults to ``['pretrain', 'finetune']`` when None or empty.
            logger: Logger instance for logging.
        """
        super().__init__(ckpt_dir, model_name, logger)

        self.stages = stages or ['pretrain', 'finetune']
        self.stage_paths = {}
        self.best_losses = {}

        # Create paths for each stage
        for stage in self.stages:
            self.stage_paths[stage] = {
                'best': os.path.join(ckpt_dir, f"{model_name}_{stage}_best.pt"),
                'last': os.path.join(ckpt_dir, f"{model_name}_{stage}_last.pt"),
            }
            self.best_losses[stage] = float('inf')

    def save_checkpoint(self, model, optimizer, epoch: int, loss: float,
                        stage: str = 'pretrain', is_best: bool = False,
                        batch_idx: Optional[int] = None,
                        additional_info: Optional[Dict[str, Any]] = None):
        """Save a stage-specific checkpoint with full training state.

        The stage's "last" checkpoint is always written; its "best"
        checkpoint is written in addition when ``is_best`` is True.

        Args:
            model: The model to save (its ``state_dict()`` is stored).
            optimizer: The optimizer whose ``state_dict()`` is stored.
            epoch (int): Current epoch number.
            loss (float): Current loss value.
            stage (str): Training stage; must be one of ``self.stages``.
            is_best (bool): Whether this is the best model so far.
            batch_idx (int, optional): Current batch index within the epoch.
            additional_info (dict, optional): Extra entries merged into the
                checkpoint. Keys that collide with the built-in ones
                override them.

        Raises:
            ValueError: If ``stage`` is not one of ``self.stages``.
        """
        if stage not in self.stages:
            raise ValueError(f"Unknown stage: {stage}. Available stages: {self.stages}")

        # The payload must carry the stage's best loss *including* this
        # save; the pre-update value would be restored as a stale best
        # (e.g. ``inf``) when training is resumed from this file.
        best_loss = loss if is_best else self.best_losses[stage]

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'stage': stage,
            'best_loss': best_loss,
        }

        if batch_idx is not None:
            checkpoint['batch_idx'] = batch_idx

        if additional_info is not None:
            checkpoint.update(additional_info)

        # Save last checkpoint for this stage
        last_path = self.stage_paths[stage]['last']
        torch.save(checkpoint, last_path)

        # Save best checkpoint if this is the best model
        if is_best:
            best_path = self.stage_paths[stage]['best']
            torch.save(checkpoint, best_path)
            # Only commit the in-memory value once the files are written.
            self.best_losses[stage] = loss
            if self.logger:
                batch_str = f" (batch {batch_idx})" if batch_idx is not None else ""
                self.logger.info(
                    f"✓ Best {stage} model saved at epoch {epoch}{batch_str} with loss {loss:.4f}"
                )

    def load_checkpoint(self, model, optimizer=None, stage: str = 'pretrain',
                        load_best: bool = True,
                        device: Optional[str | torch.device] = 'cpu'):
        """Load a stage-specific checkpoint and restore training state.

        Args:
            model: The model to load weights into.
            optimizer: The optimizer to restore state into (optional).
            stage (str): Training stage; must be one of ``self.stages``.
            load_best (bool): If True, load the best checkpoint; otherwise
                load the last one.
            device: Passed to ``torch.load`` as ``map_location``.

        Returns:
            The loaded checkpoint dictionary, or None if the file does not
            exist.

        Raises:
            ValueError: If ``stage`` is not one of ``self.stages``.
        """
        if stage not in self.stages:
            raise ValueError(f"Unknown stage: {stage}. Available stages: {self.stages}")

        ckpt_path = self.stage_paths[stage]['best' if load_best else 'last']

        if not os.path.exists(ckpt_path):
            if self.logger:
                self.logger.info(f"No {stage} checkpoint found at {ckpt_path}")
            return None

        if self.logger:
            self.logger.info(f"Loading {stage} checkpoint from {ckpt_path}")

        # SECURITY NOTE: weights_only=False unpickles arbitrary objects, so
        # only load checkpoint files that come from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

        # Restore model state
        model.load_state_dict(checkpoint['model_state_dict'])

        # Restore optimizer state if provided
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Update best loss tracking
        if 'best_loss' in checkpoint:
            self.best_losses[stage] = checkpoint['best_loss']

        if self.logger:
            epoch = checkpoint.get('epoch', 'unknown')
            loss = checkpoint.get('loss', 'unknown')
            batch_idx = checkpoint.get('batch_idx', None)
            batch_str = f", batch={batch_idx}" if batch_idx is not None else ""
            self.logger.info(f"✓ Checkpoint loaded: epoch={epoch}{batch_str}, loss={loss}")

        return checkpoint

    def has_checkpoint(self, stage: str = 'pretrain', load_best: bool = True) -> bool:
        """Check if a stage-specific checkpoint exists.

        Unlike the save/load methods, an unknown stage does not raise here;
        it simply reports False.

        Args:
            stage (str): Training stage to look up.
            load_best (bool): If True, check for the best checkpoint;
                otherwise check for the last one.

        Returns:
            bool: True if the checkpoint file exists.
        """
        if stage not in self.stages:
            return False

        ckpt_path = self.stage_paths[stage]['best' if load_best else 'last']
        return os.path.exists(ckpt_path)