Checkpoints & Recovery
Saving and resuming training state
Overview
malagent automatically saves training state during both SFT and RAFT runs, so training can recover from crashes and checkpoints can be compared across cycles.
Directory Structure
output/
├── sft/
│   ├── checkpoint-500/        # Intermediate SFT checkpoints
│   ├── checkpoint-1000/
│   ├── checkpoint-1500/
│   └── final/                 # Final SFT model
│       ├── adapter_config.json
│       ├── adapter_model.safetensors
│       ├── tokenizer.json
│       ├── tokenizer_config.json
│       └── training_args.bin
│
└── raft/
    ├── cycle_1/
    │   ├── checkpoint/        # Model after cycle 1
    │   ├── samples.jsonl      # Verified samples with rewards
    │   ├── training_state.json
    │   └── metrics.json
    ├── cycle_2/
    │   └── ...
    ├── cycle_6/
    │   └── ...
    ├── best/                  # Symlink to best checkpoint
    └── statistics.json        # Aggregate metrics
Checkpoint Contents
Model Checkpoint
checkpoint/
├── adapter_config.json        # LoRA configuration
├── adapter_model.safetensors  # LoRA weights (~500MB)
├── tokenizer.json             # Tokenizer state
├── tokenizer_config.json      # Tokenizer config
├── special_tokens_map.json    # Special token mappings
├── optimizer.pt               # Optimizer state (~1GB)
├── scheduler.pt               # LR scheduler state
└── rng_state.pth              # Random number generator state
Training State
training_state.json:
{
  "cycle": 3,
  "global_step": 4500,
  "epoch": 1.0,
  "total_samples_seen": 13656,
  "best_compile_rate": 0.334,
  "best_cycle": 3,
  "learning_rate": 4.2e-5,
  "timestamp": "2024-01-15T14:32:00Z"
}
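The state file is plain JSON, so external tooling can inspect it directly. A minimal sketch of reading it to decide where a resumed run should pick up (the helper below is illustrative, not part of malagent's API):

import json
from pathlib import Path

def load_training_state(cycle_dir: str) -> dict:
    """Read a cycle's saved state; a missing file means the cycle never finished."""
    return json.loads((Path(cycle_dir) / "training_state.json").read_text())

state = load_training_state("output/raft/cycle_3")
print(f"Last finished cycle {state['cycle']} at step {state['global_step']}; "
      f"best so far: cycle {state['best_cycle']} "
      f"({state['best_compile_rate']:.1%} compile rate)")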
Metrics
metrics.json:
{
  "cycle": 3,
  "generated": 4552,
  "compiled": 1523,
  "compile_rate": 0.334,
  "filtered": 762,
  "training_loss": 0.285,
  "eval_loss": 0.312,
  "rewards": {
    "mean": 0.67,
    "std": 0.18,
    "distribution": {
      "0.0-0.2": 2729,
      "0.2-0.5": 300,
      "0.5-0.7": 523,
      "0.7-0.9": 712,
      "0.9-1.0": 288
    }
  },
  "duration_seconds": 10832
}
Samples
samples.jsonl:
{"prompt": "Write a function...", "completion": "#include...", "reward": 0.8, "compile_status": "success", "details": null}
{"prompt": "Implement...", "completion": "#include...", "reward": 0.1, "compile_status": "error", "details": "undeclared identifier 'NtAllocate'"}
Automatic Checkpointing
SFT Checkpoints
Saved every N steps (default: 500):
sft:
  training:
    save_steps: 500
    save_total_limit: 3  # Keep only last 3
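The training_args.bin file in the SFT tree suggests the trainer is built on transformers.Trainer; under that assumption, these keys map directly onto TrainingArguments:

from transformers import TrainingArguments

# Mirrors the YAML above: a checkpoint-N/ directory every 500 steps,
# with checkpoints beyond the newest 3 deleted automatically.
args = TrainingArguments(
    output_dir="output/sft",
    save_steps=500,
    save_total_limit=3,
)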
RAFT Checkpoints
Saved after each cycle (always kept):
raft:
  checkpointing:
    save_after_cycle: true
    save_samples: true
    save_optimizer: true
Resuming Training
Auto-Resume (Recommended)
Automatically detects last checkpoint:
# SFT - resumes from latest checkpoint
python malagent/sft/trainer.py --resume
# RAFT - resumes from last completed cycle
python malagent/training/raft_trainer.py --resume
Manual Resume
Specify exact checkpoint:
# Resume from specific SFT checkpoint
python malagent/sft/trainer.py \
  --resume-from output/sft/checkpoint-1500

# Resume from specific RAFT cycle
python malagent/training/raft_trainer.py \
  --resume-from output/raft/cycle_3
Resume Behavior
When resuming RAFT from cycle N:
- Load model from cycle_N/checkpoint/
- Load optimizer/scheduler states
- Restore RNG state for reproducibility (sketched below)
- Continue from cycle N+1
- Preserve previous cycle statistics
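A hedged sketch of the restore steps for the optimizer, scheduler, and RNG. The file names match the checkpoint layout above; the dict layout of rng_state.pth is an assumption borrowed from the convention transformers.Trainer uses, so verify it against your own checkpoints:

import torch

def restore_states(ckpt_dir: str, optimizer, scheduler) -> None:
    """Reload optimizer/scheduler/RNG so cycle N+1 continues where N left off."""
    optimizer.load_state_dict(torch.load(f"{ckpt_dir}/optimizer.pt"))
    scheduler.load_state_dict(torch.load(f"{ckpt_dir}/scheduler.pt"))
    # Assumed layout: a dict keyed by RNG source; contains non-tensor objects,
    # hence weights_only=False.
    rng = torch.load(f"{ckpt_dir}/rng_state.pth", weights_only=False)
    torch.set_rng_state(rng["cpu"])
    if torch.cuda.is_available() and "cuda" in rng:
        torch.cuda.set_rng_state_all(rng["cuda"])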
Crash Recovery
Detecting Incomplete Cycles
If training crashes mid-cycle:
# Check for incomplete cycles
ls -la output/raft/*/training_state.json
# Incomplete cycle won't have training_state.json
# Resume will restart from last complete cycle
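The same check can be scripted. A sketch that mirrors the described auto-resume behavior (illustrative, not malagent's actual implementation):

import os
import re

def last_complete_cycle(raft_dir: str = "output/raft") -> int:
    """Highest cycle_N directory that wrote training_state.json."""
    last = 0
    for name in os.listdir(raft_dir):
        m = re.fullmatch(r"cycle_(\d+)", name)
        if m and os.path.exists(os.path.join(raft_dir, name, "training_state.json")):
            last = max(last, int(m.group(1)))
    return last

print(f"Resume training at cycle {last_complete_cycle() + 1}")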
Manual Recovery
If automatic resume fails:
from malagent.training import RAFTTrainer
trainer = RAFTTrainer(config_path="configs/raft_config.yaml")
# Force load from specific checkpoint
trainer.load_checkpoint("output/raft/cycle_2/checkpoint")
# Continue training
trainer.run(start_cycle=3, num_cycles=6)
Comparing Checkpoints
Evaluate Specific Cycle
# Evaluate cycle 3 checkpoint
python malagent/benchmark/evaluate.py \
--model output/raft/cycle_3/checkpoint \
--prompts data/eval_prompts.jsonl \
--samples 4
Compare Cycles
# Compare compile rates across cycles
for cycle in output/raft/cycle_*/; do
  echo "=== $cycle ==="
  jq '.compile_rate' "${cycle}metrics.json"
done
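The same comparison in Python, pulling the mean reward alongside the compile rate from each cycle's metrics.json:

import json
from pathlib import Path

for path in sorted(Path("output/raft").glob("cycle_*/metrics.json")):
    m = json.loads(path.read_text())
    print(f"cycle {m['cycle']}: compile_rate={m['compile_rate']:.3f}  "
          f"mean_reward={m['rewards']['mean']:.2f}")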
Best Checkpoint Selection
malagent maintains a best/ symlink to the best-performing cycle:
# Current best
ls -la output/raft/best
# -> cycle_6/checkpoint
# Use best for inference
python inference.py --model output/raft/best
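One way the best/ pointer can be kept fresh is an atomic symlink swap; the sketch below is illustrative, not necessarily how malagent implements it:

import os

def update_best_symlink(raft_dir: str, best_cycle: int) -> None:
    """Atomically repoint raft_dir/best at the winning cycle's checkpoint."""
    target = f"cycle_{best_cycle}/checkpoint"        # relative, so the tree can move
    tmp = os.path.join(raft_dir, "best.tmp")
    if os.path.lexists(tmp):
        os.remove(tmp)
    os.symlink(target, tmp)
    os.replace(tmp, os.path.join(raft_dir, "best"))  # rename over the old link

update_best_symlink("output/raft", 6)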
Storage Management
Checkpoint Sizes
| Component | Size (7B model) |
|---|---|
| LoRA weights | ~500MB |
| Optimizer state | ~1GB |
| Samples (per cycle) | ~50-200MB |
| Full cycle | ~2GB |
| 6 cycles total | ~12GB |
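These figures are rough estimates for a 7B base model; actual usage is easy to check per cycle:

from pathlib import Path

def dir_size_gb(path: Path) -> float:
    """Total size of all regular files under path, in GB."""
    return sum(p.stat().st_size for p in path.rglob("*") if p.is_file()) / 1e9

for cycle in sorted(Path("output/raft").glob("cycle_*")):
    print(f"{cycle.name}: {dir_size_gb(cycle):.2f} GB")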
Cleanup Old Checkpoints
# Keep only best and last 2 cycles
python scripts/cleanup_checkpoints.py \
  --keep-best \
  --keep-last 2 \
  --output-dir output/raft
Compress for Storage
# Archive a completed training run
tar -czvf training_run_2024-01-15.tar.gz \
  output/raft/best \
  output/raft/statistics.json \
  output/raft/*/metrics.json
Export for Deployment
Merge LoRA to Base Model
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-7B")

# Load and merge LoRA
model = PeftModel.from_pretrained(base_model, "output/raft/best")
merged_model = model.merge_and_unload()

# Save merged model plus tokenizer (the tokenizer files are needed
# for the GGUF conversion below)
merged_model.save_pretrained("output/merged_model")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")
tokenizer.save_pretrained("output/merged_model")
Convert to GGUF (for llama.cpp)
# After merging
python llama.cpp/convert_hf_to_gguf.py \
  output/merged_model \
  --outfile output/malagent-7b.gguf \
  --outtype q8_0