# Performance Tuning
Memory optimization and throughput
## Memory Management

### Strix Halo Unified Memory
The AMD Strix Halo’s 128GB of unified memory is shared between the CPU and the GPU. Unlike a discrete GPU with dedicated VRAM, memory pressure here affects the entire system.
| Component | Typical usage (of 128GB unified memory) |
|---|---|
| Model Weights | ~14GB (7B bf16) |
| KV Cache | ~2-8GB (varies) |
| Optimizer States | ~28GB (AdamW) |
| Gradients | ~14GB |
| Activations | ~4-20GB (varies) |
| System/OS | ~4GB |
| Available headroom | ~40-60GB (varies with KV cache and activations) |
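The fixed line items follow directly from the parameter count and dtype. A quick back-of-the-envelope check (a sketch; it assumes bf16 AdamW moments, which is what the ~28GB figure implies — full fp32 moments would be roughly double):

```python
# Rough memory budget for a 7B-parameter model, matching the table above.
params = 7e9
bytes_bf16 = 2

weights_gb   = params * bytes_bf16 / 1e9       # ~14 GB model weights
grads_gb     = params * bytes_bf16 / 1e9       # ~14 GB gradients
optimizer_gb = params * bytes_bf16 * 2 / 1e9   # ~28 GB AdamW moments (m and v, bf16)

print(f"weights={weights_gb:.0f}GB  grads={grads_gb:.0f}GB  optimizer={optimizer_gb:.0f}GB")
```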
### Memory Cleanup
Force garbage collection between phases to prevent fragmentation:
```python
import gc
import torch

def cleanup_memory():
    """Aggressive memory cleanup between operations."""
    gc.collect()
    gc.collect()  # Second pass catches cycles
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

# Use between generation and training
samples = generate_samples(prompts)
cleanup_memory()
train_on_samples(samples)
```
### Memory Monitoring
```bash
# Real-time memory usage
watch -n 1 "free -h && echo '---' && rocm-smi --showmeminfo vram"

# Memory pressure indicators
grep -E "(MemFree|MemAvailable|Buffers|Cached)" /proc/meminfo
```
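The same check can run inside the training process, for example to trigger cleanup_memory() when headroom gets low. A minimal sketch that parses /proc/meminfo (the 20GB threshold is an arbitrary example, not a project default):

```python
def available_memory_gb() -> float:
    """Return MemAvailable from /proc/meminfo in GB."""
    with open("/proc/meminfo") as f:
        for line in f:
            if line.startswith("MemAvailable:"):
                return int(line.split()[1]) / 1e6  # value is reported in kB
    return 0.0

# Example: clean up aggressively when headroom drops below 20GB
if available_memory_gb() < 20:
    cleanup_memory()
```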
## Batch Size Optimization

### Training Phase
| Batch Size | Gradient Accum | Effective Batch | Memory | Speed |
|---|---|---|---|---|
| 1 | 32 | 32 | ~45GB | Baseline |
| 2 | 16 | 32 | ~65GB | 1.4× |
| 4 | 8 | 32 | ~95GB | 1.8× |
| 8 | 4 | 32 | OOM | - |
Recommendation: Use batch_size=2 with gradient_accumulation_steps=16 for the best balance of throughput and memory headroom.
### Generation Phase
| Batch Size | Memory | Tokens/Second | Notes |
|---|---|---|---|
| 1 | ~16GB | 45 | Safe but slow |
| 4 | ~25GB | 150 | Good balance |
| 8 | ~40GB | 250 | Default |
| 16 | ~70GB | 380 | If memory allows |
```yaml
generation:
  batch_size: 8         # Samples generated in parallel
  max_new_tokens: 2048  # Max output length
  temperature: 0.7      # Sampling temperature
```
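These settings map roughly onto a manual generate() call, shown here as a sketch (it assumes model, tokenizer, and prompts are already loaded, and that the tokenizer has a pad token set for batched input):

```python
# Batch of 8 prompts, sampled with the settings from the config above
inputs = tokenizer(prompts[:8], return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=2048,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
)
texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
```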
## Gradient Checkpointing

Trades compute for memory by recomputing activations during the backward pass.
### Impact
| Setting | Memory | Speed |
|---|---|---|
| Disabled | 100% | 100% |
| Enabled | ~55% | ~70% |
### Configuration

```yaml
training:
  gradient_checkpointing: true
```
### When to Use
- Enable when batch_size=1 still causes OOM
- Disable when memory headroom > 30GB for faster training
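If you need to toggle it at runtime rather than through the config, Transformers models also expose a programmatic switch:

```python
# Enable activation checkpointing on an already-loaded model
model.gradient_checkpointing_enable()
model.config.use_cache = False   # KV caching conflicts with checkpointing during training

# ... training ...

# Switch back for generation, where speed matters more than memory
model.gradient_checkpointing_disable()
model.config.use_cache = True
```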
## Flash Attention
Flash Attention reduces memory usage and increases speed for attention computation.
### Verification
```bash
# Check that flash attention is available
python -c "from flash_attn import flash_attn_func; print('Flash Attention: OK')"

# Check which ROCm (HIP) version PyTorch was built against
python -c "import torch; print(f'ROCm: {torch.version.hip}')"
```
### Memory Savings
| Sequence Length | Standard Attention | Flash Attention | Savings |
|---|---|---|---|
| 1024 | 4GB | 0.5GB | 87% |
| 2048 | 16GB | 1GB | 94% |
| 4096 | 64GB | 2GB | 97% |
### Enabling

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
```
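Not every environment has a working flash-attn build for ROCm. One defensive pattern, shown as a sketch rather than project code, is to fall back to PyTorch's SDPA attention when the import fails:

```python
import torch
from transformers import AutoModelForCausalLM

try:
    import flash_attn  # noqa: F401
    attn_impl = "flash_attention_2"
except ImportError:
    attn_impl = "sdpa"  # PyTorch scaled-dot-product attention fallback

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    attn_implementation=attn_impl,
    torch_dtype=torch.bfloat16,
)
```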
## KV Cache Optimization

For long-sequence generation, the KV cache can consume significant memory.
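A useful rule of thumb: KV cache bytes ≈ 2 (K and V) × layers × KV heads × head dim × sequence length × batch size × bytes per value. A sketch for a Llama-style 7B model (32 layers, 32 KV heads, head dim 128, bf16 — adjust for your architecture, especially if it uses grouped-query attention):

```python
def kv_cache_gb(batch, seq_len, layers=32, kv_heads=32, head_dim=128, bytes_per_val=2):
    """Approximate KV cache size in GB."""
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_val / 1e9

print(kv_cache_gb(batch=8, seq_len=2048))  # ~8.6 GB
print(kv_cache_gb(batch=1, seq_len=4096))  # ~2.1 GB
```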
### Static Cache

Pre-allocates cache for consistent memory usage:

```python
model.generation_config.cache_implementation = "static"
model.generation_config.max_length = 4096
```
### Sliding Window

For very long sequences (if the model supports it):

```python
model.generation_config.cache_implementation = "sliding_window"
model.generation_config.sliding_window = 2048
```
## Throughput Benchmarks

### Phase Timing (7B Model, 569 Prompts × 8 Samples)
| Phase | Duration | Rate | GPU Util |
|---|---|---|---|
| Generation | ~45 min | 100 samples/min | 85% |
| Verification | ~90 min | 50 samples/min | 10% |
| Training | ~30 min | 30 samples/min | 95% |
| Full Cycle | ~3 hours | - | - |
### Bottleneck Analysis

```
Generation    █████░░░░░░░░░░░░░░░  25%  (GPU-bound)
Verification  ██████████░░░░░░░░░░  50%  (I/O-bound)
Training      █████░░░░░░░░░░░░░░░  25%  (GPU-bound)
```
Key insight: Verification (SSH compilation, Elastic polling) is the bottleneck. Consider:
- Parallel compilation workers
- Longer Elastic poll intervals
- MVR mode for faster iteration
## Compiler Optimization

### MSVC Parallel Compilation

Enable parallel compilation on DEVBOX:
```yaml
compiler:
  flags:
    - "/MP"  # Multi-process compilation
    - "/O2"  # Optimize for speed
    - "/GL"  # Whole program optimization
```
### Compilation Worker Pool

```yaml
verification:
  parallel_workers: 4       # Concurrent compilations
  compilation_timeout: 30   # Seconds per sample
```
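Because remote compilation is I/O-bound, a simple thread pool is enough to drive it. A minimal sketch, assuming a hypothetical compile_sample() helper that submits one sample to DEVBOX and enforces its own per-sample timeout (a thread pool cannot forcibly cancel a running task):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def verify_samples(samples, workers=4):
    """Compile samples concurrently; SSH compilation is I/O-bound, so threads suffice."""
    results = {}
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {pool.submit(compile_sample, sample): idx
                   for idx, sample in enumerate(samples)}
        for future in as_completed(futures):
            idx = futures[future]
            try:
                results[idx] = future.result()
            except Exception as exc:  # compile error, SSH failure, timeout
                results[idx] = exc
    return results
```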
## Profiling

### GPU Profiling
```bash
# ROCm profiler
rocprof --stats python malagent/training/raft_trainer.py --cycles 1

# Memory timeline
rocm-smi --showmeminfo vram --loop 1 > memory_trace.txt &
```
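From inside Python, PyTorch's allocator counters give a complementary per-phase view (they work on ROCm builds through the torch.cuda namespace):

```python
import torch

def log_gpu_memory(tag: str) -> None:
    """Print current and peak allocator usage in GB."""
    allocated = torch.cuda.memory_allocated() / 1e9
    peak = torch.cuda.max_memory_allocated() / 1e9
    print(f"[{tag}] allocated={allocated:.1f}GB  peak={peak:.1f}GB")

log_gpu_memory("after-generation")
torch.cuda.reset_peak_memory_stats()  # start a fresh peak for the next phase
```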
### Python Profiling
```python
import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()
# ... training code ...
profiler.disable()

stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20)
```
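To keep a run's profile around for later comparison, dump it to a file (the filename is just an example) and browse it with the standard pstats CLI:

```python
stats.dump_stats("raft_profile.prof")  # inspect later with: python -m pstats raft_profile.prof
```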
## Configuration Template (Optimized)
```yaml
# configs/raft_optimized.yaml
model:
  path: "output/sft/final"
  torch_dtype: "bfloat16"
  attn_implementation: "flash_attention_2"

generation:
  batch_size: 8
  max_new_tokens: 2048
  do_sample: true
  temperature: 0.7
  top_p: 0.95

training:
  batch_size: 2
  gradient_accumulation_steps: 16
  learning_rate: 5e-5
  gradient_checkpointing: true
  bf16: true
  optim: "adamw_torch_fused"  # Faster optimizer

verification:
  parallel_workers: 4
  timeout: 30

memory:
  cleanup_between_phases: true
  gc_collect_generations: 2
```