Performance Tuning

Memory optimization and throughput tuning for the RAFT training pipeline.

Memory Management

Strix Halo Unified Memory

The AMD Strix Halo’s 128GB of unified memory is shared between the CPU and GPU. Unlike with a discrete GPU, memory pressure affects the entire system, not just the accelerator.

┌─────────────────────────────────────────┐
│           128GB Unified Memory          │
├─────────────────────────────────────────┤
│  Model Weights    │  ~14GB (7B bf16)    │
│  KV Cache         │  ~2-8GB (varies)    │
│  Optimizer States │  ~28GB (AdamW)      │
│  Gradients        │  ~14GB              │
│  Activations      │  ~4-20GB (varies)   │
│  System/OS        │  ~4GB               │
├─────────────────────────────────────────┤
│  Available        │  ~60GB headroom     │
└─────────────────────────────────────────┘
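
As a back-of-the-envelope check of the budget above, the sketch below recomputes the fixed components. The 7B parameter count and bf16 optimizer moments are assumptions chosen to match the table, not measured values.

# Rough memory budget for full fine-tuning a 7B model in bf16.
# Assumptions (matching the table above): bf16 weights/gradients (2 bytes/param)
# and bf16 AdamW moment estimates (2 states x 2 bytes/param).
params = 7e9

weights_gb   = params * 2 / 1e9      # ~14 GB
gradients_gb = params * 2 / 1e9      # ~14 GB
optimizer_gb = params * 2 * 2 / 1e9  # ~28 GB (AdamW m and v)

print(f"weights {weights_gb:.0f}GB, grads {gradients_gb:.0f}GB, optimizer {optimizer_gb:.0f}GB")
# KV cache and activations vary with batch size and sequence length.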

Memory Cleanup

Force garbage collection between phases to prevent fragmentation:

import gc
import torch

def cleanup_memory():
    """Aggressive memory cleanup between operations."""
    gc.collect()
    gc.collect()  # Second pass frees objects released by finalizers in the first
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

# Use between generation and training
samples = generate_samples(prompts)
cleanup_memory()
train_on_samples(samples)

Memory Monitoring

# Real-time memory usage
watch -n 1 "free -h && echo '---' && rocm-smi --showmeminfo vram"

# Memory pressure indicator
cat /proc/meminfo | grep -E "(MemFree|MemAvailable|Buffers|Cached)"
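
The same check can be automated from Python before launching a memory-heavy phase. This is a small sketch; the 20GB threshold is an arbitrary example, not a tuned value.

def mem_available_gb() -> float:
    """Parse MemAvailable from /proc/meminfo (Linux)."""
    with open("/proc/meminfo") as f:
        for line in f:
            if line.startswith("MemAvailable:"):
                return int(line.split()[1]) / 1024**2  # kB -> GB
    raise RuntimeError("MemAvailable not found in /proc/meminfo")

if mem_available_gb() < 20:  # example threshold, adjust to your headroom target
    print("Warning: low memory headroom; consider reducing generation batch_size")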

Batch Size Optimization

Training Phase

Batch Size   Gradient Accum   Effective Batch   Memory   Speed
1            32               32                ~45GB    Baseline
2            16               32                ~65GB    1.4×
4            8                32                ~95GB    1.8×
8            4                32                OOM      -

Recommendation: use batch_size=2 with gradient_accumulation_steps=16 for the best balance of memory and speed.

Generation Phase

Batch Size   Memory   Tokens/Second   Notes
1            ~16GB    45              Safe but slow
4            ~25GB    150             Good balance
8            ~40GB    250             Default
16           ~70GB    380             If memory allows

Corresponding configuration:

generation:
  batch_size: 8           # Samples generated in parallel
  max_new_tokens: 2048    # Max output length
  temperature: 0.7        # Sampling temperature
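
For reference, a minimal generation sketch showing where these settings land in the transformers API. It is a simplified example: the model path is taken from the optimized config template below, and prompt loading and device placement are omitted.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "output/sft/final"  # as in the optimized config template
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
tokenizer.padding_side = "left"  # left-pad for decoder-only generation
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)

prompts = ["..."] * 8  # one batch of 8 prompts (batch_size: 8)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )
texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)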

Gradient Checkpointing

Gradient checkpointing trades compute for memory: activations are recomputed during the backward pass instead of being stored during the forward pass.

Impact

Setting     Memory   Speed
Disabled    100%     100%
Enabled     ~55%     ~70%

Configuration

training:
  gradient_checkpointing: true
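
If you enable it outside the trainer config, the roughly equivalent transformers calls look like this (a sketch, assuming model is an already-loaded transformers model):

model.gradient_checkpointing_enable()
model.config.use_cache = False  # the KV cache is incompatible with checkpointing during training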

When to Use

  • Enable when batch_size=1 still causes OOM
  • Disable when memory headroom > 30GB for faster training

Flash Attention

Flash Attention reduces memory usage and speeds up attention by computing it in tiles instead of materializing the full attention matrix.

Verification

# Check flash attention is available
python -c "from flash_attn import flash_attn_func; print('Flash Attention: OK')"

# Check ROCm compatibility
python -c "import torch; print(f'ROCm: {torch.version.hip}')"

Memory Savings

Sequence Length   Standard Attention   Flash Attention   Savings
1024              4GB                  0.5GB             87%
2048              16GB                 1GB               94%
4096              64GB                 2GB               97%

Enabling

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
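
If flash-attn may not be installed (for example, on an unsupported ROCm build), a defensive sketch is to fall back to PyTorch's SDPA backend:

import torch
from transformers import AutoModelForCausalLM

try:
    import flash_attn  # noqa: F401
    attn_impl = "flash_attention_2"
except ImportError:
    attn_impl = "sdpa"  # PyTorch scaled-dot-product attention fallback

model = AutoModelForCausalLM.from_pretrained(
    model_path,  # placeholder path, as above
    attn_implementation=attn_impl,
    torch_dtype=torch.bfloat16,
)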

KV Cache Optimization

For long-sequence generation, the KV cache can consume significant memory.
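
As a rough sizing sketch: the layer count, grouped-query KV head count, and head dimension below are illustrative assumptions, not measurements of the actual model.

# KV cache bytes = 2 (K and V) x layers x kv_heads x head_dim x seq_len x batch x bytes/elem
def kv_cache_gb(layers=32, kv_heads=8, head_dim=128, seq_len=4096, batch=8, bytes_per_elem=2):
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem / 1e9

print(f"{kv_cache_gb():.1f} GB")  # ~4.3 GB with these assumed values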

Static Cache

Pre-allocates cache for consistent memory usage:

model.generation_config.cache_implementation = "static"
model.generation_config.max_length = 4096

Sliding Window

For very long sequences (if the model supports sliding-window attention):

model.generation_config.cache_implementation = "sliding_window"
model.generation_config.sliding_window = 2048

Throughput Benchmarks

Phase Timing (7B Model, 569 Prompts × 8 Samples)

Phase          Duration   Rate              GPU Util
Generation     ~45 min    100 samples/min   85%
Verification   ~90 min    50 samples/min    10%
Training       ~30 min    30 samples/min    95%
Full Cycle     ~3 hours   -                 -

Bottleneck Analysis

Generation     █████░░░░░░░░░░░░░░░  25%  (GPU-bound)
Verification   ██████████░░░░░░░░░░  50%  (I/O-bound)
Training       █████░░░░░░░░░░░░░░░  25%  (GPU-bound)

Key insight: Verification (SSH compilation, Elastic polling) is the bottleneck. Consider:

  • Parallel compilation workers
  • Longer Elastic poll intervals
  • MVR mode for faster iteration

Compiler Optimization

MSVC Parallel Compilation

Enable parallel compilation on DEVBOX:

compiler:
  flags:
    - "/MP"        # Multi-process compilation
    - "/O2"        # Optimize for speed
    - "/GL"        # Whole program optimization

Compilation Worker Pool

verification:
  parallel_workers: 4     # Concurrent compilations
  compilation_timeout: 30  # Seconds per sample
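
A sketch of how such a worker pool could be wired up with concurrent.futures. The verify_sample callable stands in for the project's actual verification entry point and is an assumption.

from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout

PARALLEL_WORKERS = 4       # verification.parallel_workers
COMPILATION_TIMEOUT = 30   # verification.compilation_timeout, seconds

def verify_all(samples, verify_sample):
    """Run verify_sample over samples with a bounded worker pool."""
    results = []
    with ThreadPoolExecutor(max_workers=PARALLEL_WORKERS) as pool:
        futures = [pool.submit(verify_sample, s) for s in samples]
        for future in futures:
            try:
                # result(timeout=...) bounds the wait here; the remote compile
                # should also enforce its own timeout so workers are freed.
                results.append(future.result(timeout=COMPILATION_TIMEOUT))
            except FutureTimeout:
                results.append(None)  # count a timed-out compilation as a failure
    return results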

Profiling

GPU Profiling

# ROCm profiler
rocprof --stats python malagent/training/raft_trainer.py --cycles 1

# Memory timeline (one sample per second)
while true; do rocm-smi --showmeminfo vram >> memory_trace.txt; sleep 1; done &

Python Profiling

import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()

# ... training code ...

profiler.disable()
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20)

Configuration Template (Optimized)

# configs/raft_optimized.yaml
model:
  path: "output/sft/final"
  torch_dtype: "bfloat16"
  attn_implementation: "flash_attention_2"

generation:
  batch_size: 8
  max_new_tokens: 2048
  do_sample: true
  temperature: 0.7
  top_p: 0.95

training:
  batch_size: 2
  gradient_accumulation_steps: 16
  learning_rate: 5e-5
  gradient_checkpointing: true
  bf16: true
  optim: "adamw_torch_fused"  # Faster optimizer

verification:
  parallel_workers: 4
  timeout: 30

memory:
  cleanup_between_phases: true
  gc_collect_generations: 2
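
To sanity-check the file before a run, it can be loaded with PyYAML (an assumption; any YAML loader works):

import yaml

with open("configs/raft_optimized.yaml") as f:
    config = yaml.safe_load(f)

assert config["training"]["gradient_accumulation_steps"] == 16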