pyagc.utils.MultiStageCheckpointManager
- class MultiStageCheckpointManager(ckpt_dir: str, model_name: str, stages: Optional[list] = None, logger=None)
Bases: CheckpointManager
Checkpoint manager for multi-stage training.
Methods
__init__(ckpt_dir, model_name[, stages, logger]) – Initialize the checkpoint manager.
has_checkpoint([stage, load_best]) – Check whether a stage-specific checkpoint exists.
load_checkpoint(model[, optimizer, stage, ...]) – Load a checkpoint and restore training state.
save_checkpoint(model, optimizer, epoch, loss) – Save a checkpoint with full training state.
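The file layout the manager uses on disk is not documented here. The sketch below assumes one "best" and one "last" file per (model_name, stage) pair, which is one plausible way `has_checkpoint(stage, load_best)` could be answered. The helper names `checkpoint_path` and `has_checkpoint`, the file-name pattern, the `.pkl` extension, and the defaults (`'pretrain'`, `True`) are all hypothetical.

```python
import os
import tempfile


# Hypothetical naming scheme: one "best" and one "last" file per
# (model_name, stage) pair. pyagc's actual file names may differ.
def checkpoint_path(ckpt_dir: str, model_name: str,
                    stage: str = "pretrain", best: bool = True) -> str:
    kind = "best" if best else "last"
    return os.path.join(ckpt_dir, f"{model_name}_{stage}_{kind}.pkl")


# Sketch of a stage-aware existence check; the default values are assumed.
def has_checkpoint(ckpt_dir: str, model_name: str,
                   stage: str = "pretrain", load_best: bool = True) -> bool:
    return os.path.isfile(checkpoint_path(ckpt_dir, model_name, stage, load_best))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        # Pretend a best pretraining checkpoint was written earlier.
        open(checkpoint_path(d, "gae", "pretrain"), "wb").close()
        print(has_checkpoint(d, "gae", "pretrain"))   # True
        print(has_checkpoint(d, "gae", "finetune"))   # False
```

Keying files by stage lets a fine-tuning run look for its own checkpoints without being confused by pretraining ones.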
- save_checkpoint(model, optimizer, epoch: int, loss: float, stage: str = 'pretrain', is_best: bool = False, batch_idx: Optional[int] = None, additional_info: Optional[Dict[str, Any]] = None)
Save a checkpoint with full training state.
- Parameters:
model – The model to save
optimizer – The optimizer whose state is saved
epoch (int) – Current epoch number
loss (float) – Current loss value
stage (str) – Training stage (‘pretrain’ or ‘finetune’)
is_best (bool) – Whether this is the best model so far
batch_idx (int, optional) – Current batch index within the epoch
additional_info (dict, optional) – Additional information to save
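What "full training state" contains is not spelled out beyond the parameters above. The following is a minimal sketch, assuming the checkpoint is a dictionary that bundles those parameters, that the "last" file is always refreshed, and that the "best" file is overwritten only when `is_best` is set. It uses `pickle` in place of `torch.save` so it runs without PyTorch. `StubState`, the dictionary keys, and the file-name pattern are hypothetical.

```python
import os
import pickle
from typing import Any, Dict, Optional


class StubState:
    """Hypothetical stand-in for a torch model/optimizer exposing state_dict()."""

    def __init__(self, **state: Any) -> None:
        self._state = state

    def state_dict(self) -> Dict[str, Any]:
        return dict(self._state)


def save_checkpoint(ckpt_dir: str, model_name: str, model, optimizer,
                    epoch: int, loss: float, stage: str = "pretrain",
                    is_best: bool = False, batch_idx: Optional[int] = None,
                    additional_info: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Bundle everything needed to resume training from this point.
    state = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "loss": loss,
        "stage": stage,
        "batch_idx": batch_idx,
        "additional_info": dict(additional_info or {}),
    }
    os.makedirs(ckpt_dir, exist_ok=True)
    # Always refresh the "last" file; also overwrite "best" when is_best is set.
    kinds = ["last", "best"] if is_best else ["last"]
    for kind in kinds:
        path = os.path.join(ckpt_dir, f"{model_name}_{stage}_{kind}.pkl")
        with open(path, "wb") as f:
            pickle.dump(state, f)
    return state
```

Storing `batch_idx` alongside `epoch` is what makes mid-epoch resumption possible; a caller that only checkpoints at epoch boundaries can leave it as `None`.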
- load_checkpoint(model, optimizer=None, stage: str = 'pretrain', load_best: bool = True, device: Optional[Union[str, device]] = 'cpu')
Load a checkpoint and restore training state.
- Parameters:
model – The model to load weights into
optimizer (optional) – The optimizer whose state is restored
stage (str) – Training stage (‘pretrain’ or ‘finetune’)
load_best (bool) – If True, load best checkpoint; otherwise load last
device (Union[str, device, None], default: 'cpu') – Device to map the checkpoint to
- Returns:
Dictionary with checkpoint information, or None if no checkpoint is found
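The loading path can be sketched the same way. This minimal version assumes the hypothetical "best"/"last" file layout and dictionary keys used for saving, returns `None` when the requested file is missing, restores the optimizer only if one is passed, and returns the non-weight fields as the "checkpoint information". It reads with `pickle`; in a PyTorch implementation the `device` argument would typically be passed as `map_location` to `torch.load`, so it is omitted here. `StubModule` and the file names are hypothetical.

```python
import os
import pickle
from typing import Any, Dict, Optional


class StubModule:
    """Hypothetical stand-in for a torch model/optimizer exposing load_state_dict()."""

    def __init__(self) -> None:
        self.state: Optional[Dict[str, Any]] = None

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.state = dict(state)


def load_checkpoint(ckpt_dir: str, model_name: str, model, optimizer=None,
                    stage: str = "pretrain",
                    load_best: bool = True) -> Optional[Dict[str, Any]]:
    kind = "best" if load_best else "last"
    path = os.path.join(ckpt_dir, f"{model_name}_{stage}_{kind}.pkl")
    if not os.path.isfile(path):
        return None  # mirrors the documented "None if no checkpoint is found"
    # A PyTorch version would call torch.load(path, map_location=device) here.
    with open(path, "rb") as f:
        state = pickle.load(f)
    model.load_state_dict(state["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(state["optimizer_state_dict"])
    # Return only the bookkeeping fields (epoch, loss, ...), not the weights.
    return {k: v for k, v in state.items() if not k.endswith("_state_dict")}
```

Returning `None` instead of raising lets a training script call `load_checkpoint(..., stage='finetune')` unconditionally and fall back to the pretraining checkpoint, or to a fresh start, when nothing is found.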