Overview
GPU memory (VRAM) is often the bottleneck when training large AI models. A 7B parameter model needs ~14GB just for weights in FP16, and training requires 3-4x more for gradients, optimizer states, and activations.
The good news: you can dramatically reduce memory usage with the right techniques. This guide covers everything from simple flags to advanced distributed strategies.
Quick Wins (Try First)
- 1. Mixed Precision: `model.half()` — 50% memory savings
- 2. Gradient Checkpointing: One line — 60% savings
- 3. Smaller Batch Size: Reduce and use accumulation
- 4. FlashAttention: Huge savings for transformers
Where Does Memory Go?
Understanding memory breakdown helps you target the right optimizations:
| Component | What it is | How to reduce |
|---|---|---|
| Model Parameters | The actual weights of your model | Use smaller models, pruning, quantization |
| Gradients | Gradients computed during backprop | Gradient checkpointing, accumulation |
| Optimizer States | Adam stores 2 states per parameter | Use SGD, Adafactor, or 8-bit optimizers |
| Activations | Intermediate values for backprop | Gradient checkpointing, smaller batches |
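As a rough sketch of the arithmetic behind this breakdown (illustrative byte counts assuming FP16 weights and gradients plus FP32 Adam states; activations are omitted because they depend on batch size, sequence length, and depth):

```python
# Back-of-the-envelope memory per component, in GB (illustrative only:
# assumes FP16 weights/gradients and FP32 Adam states, and ignores
# activations, which depend on batch size and architecture).
def memory_breakdown_gb(n_params: float) -> dict:
    return {
        "weights_fp16":     n_params * 2 / 1e9,  # 2 bytes per parameter
        "gradients_fp16":   n_params * 2 / 1e9,  # 2 bytes per parameter
        "adam_states_fp32": n_params * 8 / 1e9,  # momentum + variance, 4 bytes each
    }

print(memory_breakdown_gb(7e9))
# {'weights_fp16': 14.0, 'gradients_fp16': 14.0, 'adam_states_fp32': 56.0}
```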
The Adam Optimizer Problem
Adam stores 2 additional states per parameter (momentum + variance), each in FP32. For a 7B model, that's roughly 56GB just for optimizer states. Use 8-bit Adam or switch to SGD/Adafactor for huge savings.
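A minimal sketch of the 8-bit route, assuming the bitsandbytes package is installed (the single linear layer below is just a placeholder for a real model):

```python
# Minimal sketch: swap FP32 Adam for 8-bit Adam via bitsandbytes
# (assumes `pip install bitsandbytes` and a CUDA GPU; the layer below
# is a placeholder model).
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096).cuda()

# Drop-in replacement for torch.optim.AdamW: momentum and variance are
# stored in 8 bits, roughly quartering optimizer-state memory.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=2e-4)
```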
Mixed Precision Training
Mixed precision uses FP16 (or BF16) for most operations while keeping FP32 for numerically sensitive parts. It's the easiest optimization with the biggest impact.
```python
# PyTorch Automatic Mixed Precision
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with autocast():                        # FP16 forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()           # scaled backward pass
    scaler.step(optimizer)                  # unscales gradients, then steps
    scaler.update()                         # adjusts the loss scale
```

FP16 vs BF16
FP16:
- Wider hardware support
- Needs loss scaling
- Can have overflow issues

BF16:
- Same range as FP32
- No loss scaling needed
- Requires Ampere+ GPU
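On an Ampere-or-newer GPU, a minimal BF16 variant of the loop above drops the GradScaler entirely (this sketch reuses the model, criterion, optimizer, and dataloader names from that example):

```python
# BF16 sketch: same loop as above but without loss scaling, since BF16
# has the same exponent range as FP32 (assumes an Ampere+ GPU).
import torch

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    loss.backward()      # plain backward pass, no GradScaler needed
    optimizer.step()
```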
Gradient Checkpointing
Gradient checkpointing trades compute for memory. Instead of storing all activations for the backward pass, it recomputes them on-the-fly. This can reduce memory by 60-70% with only ~20% slowdown.
```python
# Enable gradient checkpointing in transformers
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Enable checkpointing
model.gradient_checkpointing_enable()
# Now training uses ~60% less activation memory
```

When to Use
- Training large models on limited VRAM
- Fine-tuning when memory is tight
- When batch size is more important than speed
Trade-offs
- ~20-30% slower training
- More GPU compute used
- Not all layers benefit equally
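For plain PyTorch models outside the transformers library, the same compute-for-memory trade is available through torch.utils.checkpoint. A small sketch with an illustrative stack of linear layers (assumes a recent PyTorch that accepts use_reentrant):

```python
# Sketch: checkpointing a plain nn.Sequential with torch.utils.checkpoint
# (the 8-layer stack here is illustrative, not from the guide).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)]).cuda()
x = torch.randn(16, 1024, device="cuda", requires_grad=True)

# Split into 4 segments: only segment-boundary activations are stored;
# the rest are recomputed during the backward pass.
out = checkpoint_sequential(layers, 4, x, use_reentrant=False)
out.sum().backward()
```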
FlashAttention
FlashAttention is a game-changer for transformer models. Standard attention has O(n²) memory complexity — FlashAttention reduces this to O(n) while also being 2-4x faster through better GPU memory access patterns.
```python
# Using FlashAttention with transformers
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # enable FlashAttention-2
    device_map="auto",
)

# FlashAttention reduces attention memory from O(n²) to O(n)
# and is 2-4x faster for long sequences
```

DeepSpeed ZeRO
DeepSpeed ZeRO (Zero Redundancy Optimizer) partitions model states across GPUs, enabling training of models that don't fit on a single GPU.
| Stage | Partitions | Memory Reduction | Use Case |
|---|---|---|---|
| ZeRO-1 | Optimizer states | 4x | Multi-GPU, minimal overhead |
| ZeRO-2 | + Gradients | 8x | Most common choice |
| ZeRO-3 | + Parameters | 64x+ | Very large models |
A typical DeepSpeed ZeRO-2 configuration with CPU optimizer offload:

```json
{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "allgather_partitions": true,
    "reduce_scatter": true
  },
  "fp16": {
    "enabled": true,
    "loss_scale_window": 100
  },
  "gradient_accumulation_steps": 4,
  "train_micro_batch_size_per_gpu": 2
}
```
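As a sketch of how such a config is usually wired in, assuming the Hugging Face Trainer with DeepSpeed support and that the JSON above is saved as ds_config.json (`model` and `train_dataset` stand in for your own objects):

```python
# Hypothetical wiring: hand the ZeRO-2 config above to the Hugging Face
# Trainer (assumes it is saved as ds_config.json; `model` and
# `train_dataset` are placeholders).
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=2,   # matches train_micro_batch_size_per_gpu
    gradient_accumulation_steps=4,   # matches the config above
    fp16=True,
    deepspeed="ds_config.json",      # hands sharding + offload to DeepSpeed
)

Trainer(model=model, args=args, train_dataset=train_dataset).train()
```

Multi-GPU runs are then typically launched with the deepspeed launcher (or torchrun), which spawns one process per GPU.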
Batch Size Optimization
Batch size directly impacts memory. Use gradient accumulation to simulate larger effective batch sizes without the memory cost.
Gradient Accumulation Formula
effective_batch = micro_batch × accumulation_steps × num_gpus
Example: micro_batch=2, accumulation=8, gpus=4 → effective_batch = 64
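A minimal single-GPU sketch of the accumulation loop (reusing the model, optimizer, criterion, and dataloader names from the mixed-precision example above; accumulation_steps = 8 matches the example):

```python
# Gradient accumulation sketch: step the optimizer once every
# accumulation_steps micro-batches (reuses model/optimizer/criterion/
# dataloader names from the mixed-precision example above).
accumulation_steps = 8

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(dataloader):
    loss = criterion(model(inputs), targets)
    (loss / accumulation_steps).backward()   # scale so gradients average out
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                     # one update per virtual batch
        optimizer.zero_grad()
```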
Finding Optimal Batch Size
- 1. Start with batch_size=1
- 2. Double until OOM (see the probe sketch after this list)
- 3. Use 80% of max as safe value
- 4. Add accumulation for effective size
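A naive probe along those lines might look like this (a sketch only: make_batch is a hypothetical helper that builds a batch of the requested size, and torch.cuda.OutOfMemoryError requires a recent PyTorch):

```python
# Naive batch-size probe: double until the GPU runs out of memory, then
# fall back to the last size that fit (make_batch is a hypothetical
# helper returning a batch of the requested size).
import torch

def find_max_batch_size(model, make_batch, start=1):
    bs = start
    while True:
        try:
            model(make_batch(bs)).sum().backward()
            model.zero_grad(set_to_none=True)
            bs *= 2                          # fits: try twice as large
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            return bs // 2                   # last size that fit
```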
(Chart: GPU memory usage vs. batch size)
Practical Tips
Start with Mixed Precision
Always enable FP16/BF16 first. It's free performance and memory.
`model.half()  # or use autocast()`

Use Gradient Accumulation
Simulate larger batches without more memory.
`accumulation_steps = 8  # effective_batch = micro_batch * 8`

Monitor with nvidia-smi
Watch memory usage in real-time to find bottlenecks.
`watch -n 0.5 nvidia-smi`

Clear Cache Regularly
Free up fragmented memory between operations.
`torch.cuda.empty_cache()`

Use torch.no_grad() for Inference
Disable gradient tracking when not training.
`with torch.no_grad(): outputs = model(inputs)`

Profile Memory Usage
Find exactly where memory is being used.
`torch.cuda.memory_summary()`

Technique Comparison
| Technique | Memory Saved | Speed | Complexity |
|---|---|---|---|
| Mixed Precision (FP16/BF16) | 50% | 2-3x | Low |
| Gradient Checkpointing | 60-70% | 0.8x (slower) | Low |
| FlashAttention | 5-20x (attention) | 2-4x | Medium |
| DeepSpeed ZeRO-2 | 8x | 1x | Medium |
| DeepSpeed ZeRO-3 | 64x+ | 0.9x | High |
| 8-bit Optimizers | 75% (optimizer) | 1x | Low |