pyagc.utils.MultiStageCheckpointManager

class MultiStageCheckpointManager(ckpt_dir: str, model_name: str, stages: Optional[list] = None, logger=None)

Bases: CheckpointManager

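Checkpoint manager for multi-stage training. It keeps stage-specific checkpoints (for example, separate pretraining and finetuning checkpoints) and distinguishes the best checkpoint of each stage from the last one.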

__init__(ckpt_dir: str, model_name: str, stages: Optional[list] = None, logger=None)
Parameters:
  • ckpt_dir (str) – Directory to save checkpoints

  • model_name (str) – Base name for checkpoint files

  • stages (Optional[list], default: None) – List of stage names (e.g., [‘pretrain’, ‘finetune’])

  • logger (default: None) – Logger instance used to emit log messages
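One way a manager like this can keep stages apart on disk is to give every (stage, best/last) pair its own file under ckpt_dir. The sketch below assumes a `<model_name>_<stage>_<best|last>.pt` layout; both the layout and the `checkpoint_path` helper are hypothetical and are not taken from pyagc, whose actual file names may differ.

```python
import os


def checkpoint_path(ckpt_dir: str, model_name: str, stage: str, best: bool) -> str:
    """Hypothetical helper: build a stage-specific checkpoint path.

    The naming scheme is an assumption for illustration, not pyagc's.
    """
    kind = "best" if best else "last"
    return os.path.join(ckpt_dir, f"{model_name}_{stage}_{kind}.pt")


# One file per (stage, best/last) pair keeps pretraining and finetuning apart.
for stage in ["pretrain", "finetune"]:
    for best in (True, False):
        print(checkpoint_path("ckpts", "gcn", stage, best))
```

With this kind of layout, saving a finetuning checkpoint can never overwrite the pretraining one, because the stage name is part of the file name.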

Methods

__init__(ckpt_dir, model_name[, stages, logger])

Initialize the multi-stage checkpoint manager.

has_checkpoint([stage, load_best])

Check if a stage-specific checkpoint exists.

load_checkpoint(model[, optimizer, stage, ...])

Load a checkpoint and restore training state.

save_checkpoint(model, optimizer, epoch, loss[, ...])

Save a checkpoint with full training state.

save_checkpoint(model, optimizer, epoch: int, loss: float, stage: str = 'pretrain', is_best: bool = False, batch_idx: Optional[int] = None, additional_info: Optional[Dict[str, Any]] = None)

Save a checkpoint with full training state.

Parameters:
  • model – The model to save

  • optimizer – The optimizer whose state is saved

  • epoch (int) – Current epoch number

  • loss (float) – Current loss value

  • stage (str, default: 'pretrain') – Training stage name (e.g., ‘pretrain’ or ‘finetune’)

  • is_best (bool, default: False) – Whether this is the best model so far

  • batch_idx (Optional[int], default: None) – Current batch index within the epoch

  • additional_info (Optional[Dict[str, Any]], default: None) – Additional information to store in the checkpoint

load_checkpoint(model, optimizer=None, stage: str = 'pretrain', load_best: bool = True, device: Optional[Union[str, device]] = 'cpu')

Load a checkpoint and restore training state.

Parameters:
  • model – The model to load weights into

  • optimizer – The optimizer whose state is restored (optional)

  • stage (str, default: 'pretrain') – Training stage name (e.g., ‘pretrain’ or ‘finetune’)

  • load_best (bool, default: True) – If True, load the best checkpoint; otherwise load the last one

  • device (Union[str, device, None], default: 'cpu') – Device to map the checkpoint to

Returns:

A dictionary with the checkpoint information, or None if no matching checkpoint is found

has_checkpoint(stage: str = 'pretrain', load_best: bool = True) → bool

Check if a stage-specific checkpoint exists.

Parameters:
  • stage (str, default: 'pretrain') – Training stage name (e.g., ‘pretrain’ or ‘finetune’)

  • load_best (bool, default: True) – If True, check for the best checkpoint; otherwise check for the last one

Returns:

True if the checkpoint exists, False otherwise

Return type:

bool
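A typical use of an existence check is to skip stages that have already finished when a multi-stage run is restarted. The sketch below is hypothetical: `has_stage_checkpoint` and `plan_stages` are illustrative helpers, not part of pyagc, and the `<model_name>_<stage>_<best|last>.pkl` layout is assumed. It treats a stage as done once its best checkpoint exists.

```python
import os
import tempfile


def has_stage_checkpoint(ckpt_dir: str, model_name: str, stage: str = "pretrain",
                         load_best: bool = True) -> bool:
    """Hypothetical stand-in for has_checkpoint; not pyagc's implementation."""
    kind = "best" if load_best else "last"
    return os.path.isfile(os.path.join(ckpt_dir, f"{model_name}_{stage}_{kind}.pkl"))


def plan_stages(ckpt_dir: str, model_name: str, stages: list) -> list:
    """Return the stages that still need training (hypothetical resume logic)."""
    return [s for s in stages if not has_stage_checkpoint(ckpt_dir, model_name, s)]


with tempfile.TemporaryDirectory() as d:
    # Pretend pretraining already finished and produced a best checkpoint.
    open(os.path.join(d, "gcn_pretrain_best.pkl"), "wb").close()
    print(has_stage_checkpoint(d, "gcn", "pretrain"))                   # True
    print(has_stage_checkpoint(d, "gcn", "pretrain", load_best=False))  # False
    print(plan_stages(d, "gcn", ["pretrain", "finetune"]))              # ['finetune']
```

Checking load_best=False instead would answer a different question: whether an interrupted stage left a "last" checkpoint to resume from.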