Checkpoints & Recovery

Saving and resuming training state

Overview

malagent automatically saves training state, enabling recovery from crashes and comparison between cycles.

Directory Structure

output/
├── sft/
│   ├── checkpoint-500/       # Intermediate SFT checkpoints
│   ├── checkpoint-1000/
│   ├── checkpoint-1500/
│   └── final/                # Final SFT model
│       ├── adapter_config.json
│       ├── adapter_model.safetensors
│       ├── tokenizer.json
│       ├── tokenizer_config.json
│       └── training_args.bin
│
└── raft/
    ├── cycle_1/
    │   ├── checkpoint/       # Model after cycle 1
    │   ├── samples.jsonl     # Verified samples with rewards
    │   ├── training_state.json
    │   └── metrics.json
    ├── cycle_2/
    │   └── ...
    ├── cycle_6/
    │   └── ...
    ├── best/                 # Symlink to best checkpoint
    └── statistics.json       # Aggregate metrics

Checkpoint Contents

Model Checkpoint

checkpoint/
├── adapter_config.json       # LoRA configuration
├── adapter_model.safetensors # LoRA weights (~500MB)
├── tokenizer.json            # Tokenizer state
├── tokenizer_config.json     # Tokenizer config
├── special_tokens_map.json   # Special token mappings
├── optimizer.pt              # Optimizer state (~1GB)
├── scheduler.pt              # LR scheduler state
└── rng_state.pth             # Random number generator state

Training State

training_state.json:

{
  "cycle": 3,
  "global_step": 4500,
  "epoch": 1.0,
  "total_samples_seen": 13656,
  "best_compile_rate": 0.334,
  "best_cycle": 3,
  "learning_rate": 4.2e-5,
  "timestamp": "2024-01-15T14:32:00Z"
}

Metrics

metrics.json:

{
  "cycle": 3,
  "generated": 4552,
  "compiled": 1523,
  "compile_rate": 0.334,
  "filtered": 762,
  "training_loss": 0.285,
  "eval_loss": 0.312,
  "rewards": {
    "mean": 0.67,
    "std": 0.18,
    "distribution": {
      "0.0-0.2": 2729,
      "0.2-0.5": 300,
      "0.5-0.7": 523,
      "0.7-0.9": 712,
      "0.9-1.0": 288
    }
  },
  "duration_seconds": 10832
}

Samples

samples.jsonl:

{"prompt": "Write a function...", "completion": "#include...", "reward": 0.8, "compile_status": "success", "details": null}
{"prompt": "Implement...", "completion": "#include...", "reward": 0.1, "compile_status": "error", "details": "undeclared identifier 'NtAllocate'"}

Automatic Checkpointing

SFT Checkpoints

Saved every N steps (default: 500):

sft:
  training:
    save_steps: 500
    save_total_limit: 3    # Keep only last 3

RAFT Checkpoints

Saved after each cycle (always kept):

raft:
  checkpointing:
    save_after_cycle: true
    save_samples: true
    save_optimizer: true

Resuming Training

With --resume, training automatically detects and continues from the latest checkpoint:

# SFT - resumes from latest checkpoint
python malagent/sft/trainer.py --resume

# RAFT - resumes from last completed cycle
python malagent/training/raft_trainer.py --resume

Manual Resume

Specify exact checkpoint:

# Resume from specific SFT checkpoint
python malagent/sft/trainer.py \
    --resume-from output/sft/checkpoint-1500

# Resume from specific RAFT cycle
python malagent/training/raft_trainer.py \
    --resume-from output/raft/cycle_3

Resume Behavior

When resuming RAFT from cycle N:

  1. Load model from cycle_N/checkpoint/
  2. Load optimizer/scheduler states
  3. Restore RNG state for reproducibility
  4. Continue from cycle N+1
  5. Preserve previous cycle statistics

Crash Recovery

Detecting Incomplete Cycles

If training crashes mid-cycle:

# Check for incomplete cycles
ls -la output/raft/*/training_state.json

# Incomplete cycle won't have training_state.json
# Resume will restart from last complete cycle
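
A small helper along these lines can locate the last complete cycle programmatically (a sketch based on the directory layout above; malagent's own resume logic may differ):

from pathlib import Path

def last_complete_cycle(raft_dir="output/raft"):
    """Return the highest cycle number that has a training_state.json."""
    complete = [
        int(state.parent.name.split("_")[1])
        for state in Path(raft_dir).glob("cycle_*/training_state.json")
    ]
    return max(complete) if complete else None

print(last_complete_cycle())  # e.g. 3 if cycle_4 crashed before finishing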

Manual Recovery

If automatic resume fails:

from malagent.training import RAFTTrainer

trainer = RAFTTrainer(config_path="configs/raft_config.yaml")

# Force load from specific checkpoint
trainer.load_checkpoint("output/raft/cycle_2/checkpoint")

# Continue training
trainer.run(start_cycle=3, num_cycles=6)

Comparing Checkpoints

Evaluate Specific Cycle

# Evaluate cycle 3 checkpoint
python malagent/benchmark/evaluate.py \
    --model output/raft/cycle_3/checkpoint \
    --prompts data/eval_prompts.jsonl \
    --samples 4

Compare Cycles

# Compare compile rates across cycles
for cycle in output/raft/cycle_*/; do
    echo "=== $cycle ==="
    jq '.compile_rate' "$cycle/metrics.json"
done
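
The same comparison can be done in Python, pulling a few extra fields from each metrics.json (field names follow the example above):

import json
from pathlib import Path

for metrics_file in sorted(Path("output/raft").glob("cycle_*/metrics.json")):
    m = json.loads(metrics_file.read_text())
    print(f"cycle {m['cycle']}: compile_rate={m['compile_rate']:.3f}  "
          f"mean_reward={m['rewards']['mean']:.2f}  filtered={m['filtered']}")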

Best Checkpoint Selection

malagent maintains a best/ symlink to the best-performing cycle:

# Current best
ls -la output/raft/best
# -> cycle_6/checkpoint

# Use best for inference
python inference.py --model output/raft/best
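
If the link ever needs to be repointed by hand (for example, after deleting the cycle it referenced), a relative target keeps the layout portable. This is just one way to do it, not necessarily how malagent updates the link internally:

from pathlib import Path

best = Path("output/raft/best")
if best.is_symlink():
    best.unlink()                       # drop the stale link
best.symlink_to("cycle_4/checkpoint")   # relative target; cycle_4 is just an example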

Storage Management

Checkpoint Sizes

Component              Size (7B model)
LoRA weights           ~500MB
Optimizer state        ~1GB
Samples (per cycle)    ~50-200MB
Full cycle             ~2GB
6 cycles total         ~12GB

Cleanup Old Checkpoints

# Keep only best and last 2 cycles
python scripts/cleanup_checkpoints.py \
    --keep-best \
    --keep-last 2 \
    --output-dir output/raft
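
If the helper script is unavailable, equivalent cleanup can be done directly. The sketch below (assuming the directory layout above and Python 3.9+) removes only the heavy checkpoint/ directories of older cycles, keeping their metrics and never touching the cycle the best/ symlink points to:

import shutil
from pathlib import Path

raft_dir = Path("output/raft")
best_target = (raft_dir / "best").resolve()   # e.g. .../cycle_6/checkpoint
cycles = sorted(raft_dir.glob("cycle_*"), key=lambda p: int(p.name.split("_")[1]))

for cycle_dir in cycles[:-2]:                 # keep the last 2 cycles intact
    if best_target.is_relative_to(cycle_dir.resolve()):
        continue                              # never remove the best cycle
    shutil.rmtree(cycle_dir / "checkpoint", ignore_errors=True)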

Compress for Storage

# Archive a completed training run
tar -czvf training_run_2024-01-15.tar.gz \
    output/raft/best \
    output/raft/statistics.json \
    output/raft/*/metrics.json

Export for Deployment

Merge LoRA to Base Model

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-7B")

# Load and merge LoRA
model = PeftModel.from_pretrained(base_model, "output/raft/best")
merged_model = model.merge_and_unload()

# Save merged model and tokenizer (the tokenizer files are needed for GGUF conversion below)
merged_model.save_pretrained("output/merged_model")
AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B").save_pretrained("output/merged_model")
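
The merged directory can then be loaded directly with AutoModelForCausalLM.from_pretrained("output/merged_model"), with no peft dependency at inference time. Note that merging materializes full-precision weights, so the merged model is far larger on disk than the ~500MB adapter alone.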

Convert to GGUF (for llama.cpp)

# After merging
python llama.cpp/convert_hf_to_gguf.py \
    output/merged_model \
    --outfile output/malagent-7b.gguf \
    --outtype q8_0
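
q8_0 is a near-lossless 8-bit quantization; smaller GGUF variants (q4_k_m, q5_k_m, etc.) are produced afterwards with llama.cpp's separate quantization tool rather than during conversion.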