
GPU Memory Optimization: Train Larger Models on Limited Hardware

Running out of GPU memory? This guide covers every technique to reduce VRAM usage and train larger models — from mixed precision to DeepSpeed ZeRO. With code examples.

At a glance:
  • Up to 90% memory saved with all techniques combined
  • 50% memory reduction from mixed precision
  • 5-20x less attention memory with FlashAttention
  • 64x+ multi-GPU scaling with DeepSpeed ZeRO-3
Griddly Team · Updated December 2025

Overview

GPU memory (VRAM) is often the bottleneck when training large AI models. A 7B parameter model needs ~14GB just for weights in FP16, and training requires 3-4x more for gradients, optimizer states, and activations.

The good news: you can dramatically reduce memory usage with the right techniques. This guide covers everything from simple flags to advanced distributed strategies.

Quick Wins (Try First)

  1. Mixed Precision: model.half() for 50% memory savings
  2. Gradient Checkpointing: one line for ~60% savings
  3. Smaller Batch Size: reduce it and use gradient accumulation
  4. FlashAttention: huge savings for transformers

Where Does Memory Go?

Understanding memory breakdown helps you target the right optimizations:

Component | Typical Share | What it is | How to reduce it
Model Parameters | ~25% | The actual weights of your model | Smaller models, pruning, quantization
Gradients | ~25% | Gradients computed during backprop | Gradient checkpointing, accumulation
Optimizer States | ~35% | Adam stores 2 states per parameter | SGD, Adafactor, or 8-bit optimizers
Activations | ~15% | Intermediate values for backprop | Gradient checkpointing, smaller batches
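
For a quick sanity check before launching a run, you can estimate the first three buckets directly from the parameter count. The sketch below assumes FP16 weights and gradients with Adam states in FP32; activations are excluded because they depend heavily on batch size and sequence length.

# Back-of-envelope estimate of the static memory buckets (activations excluded).
# Assumes FP16 weights and gradients, and Adam momentum/variance kept in FP32.
def estimate_static_memory_gb(num_params):
    gib = 1024 ** 3
    return {
        "weights_fp16": num_params * 2 / gib,      # 2 bytes per parameter
        "gradients_fp16": num_params * 2 / gib,    # 2 bytes per parameter
        "adam_states_fp32": num_params * 8 / gib,  # momentum + variance, 4 bytes each
    }

print(estimate_static_memory_gb(7_000_000_000))
# {'weights_fp16': ~13.0, 'gradients_fp16': ~13.0, 'adam_states_fp32': ~52.2}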

The Adam Optimizer Problem

Adam stores 2 additional states per parameter (momentum + variance). For a 7B model, that's about 56GB just for optimizer states in FP32 (7B parameters × 2 states × 4 bytes). Use 8-bit Adam or switch to SGD/Adafactor for huge savings.
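
For example, the bitsandbytes library provides a drop-in 8-bit Adam. A minimal sketch, with illustrative hyperparameters:

# 8-bit Adam from bitsandbytes: optimizer states stored in 8 bits instead of 32.
# Sketch only; assumes bitsandbytes is installed and `model` is your nn.Module.
import bitsandbytes as bnb

optimizer = bnb.optim.Adam8bit(
    model.parameters(),    # same call signature as torch.optim.Adam
    lr=1e-4,               # illustrative hyperparameters
    betas=(0.9, 0.999),
)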

Mixed Precision Training

Mixed precision uses FP16 (or BF16) for most operations while keeping FP32 for numerically sensitive parts. It's the easiest optimization with the biggest impact.

Typical results: 50% memory savings, a 2-3x speed increase, and roughly zero accuracy loss.

PyTorch AMP (recommended)
# PyTorch Automatic Mixed Precision
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    inputs, targets = batch        # unpack the batch into inputs and labels
    optimizer.zero_grad()

    with autocast():  # FP16 forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()                # adjusts the loss scale for the next iteration

FP16 vs BF16

FP16 (Float16)
  • Wider hardware support
  • Needs loss scaling
  • Can have overflow/underflow issues

BF16 (BFloat16)
  • Same dynamic range as FP32
  • No loss scaling needed
  • Requires an Ampere or newer GPU
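
On an Ampere or newer GPU, BF16 is a small change to the loop above, and the GradScaler can be dropped because BF16 keeps FP32's dynamic range. A minimal sketch:

# BF16 autocast (Ampere or newer); no GradScaler needed
import torch

for batch in dataloader:
    inputs, targets = batch
    optimizer.zero_grad()

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    loss.backward()   # no loss scaling: BF16 has the same exponent range as FP32
    optimizer.step()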

Gradient Checkpointing

Gradient checkpointing trades compute for memory. Instead of storing all activations for the backward pass, it recomputes them on-the-fly. This can reduce memory by 60-70% with only ~20% slowdown.

Hugging Face Transformers
# Enable gradient checkpointing in transformers
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto"
)

# Enable checkpointing
model.gradient_checkpointing_enable()

# Now training uses ~60% less memory

When to Use

  • Training large models on limited VRAM
  • Fine-tuning when memory is tight
  • When batch size is more important than speed

Trade-offs

  • ~20-30% slower training
  • More GPU compute used
  • Not all layers benefit equally
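
Outside of the transformers library, you can get the same effect in plain PyTorch with torch.utils.checkpoint. The wrapper below is a sketch that assumes your model is a stack of blocks:

# Plain PyTorch gradient checkpointing: recompute each block's activations
# in the backward pass instead of storing them.
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(torch.nn.Module):
    def __init__(self, blocks):
        super().__init__()
        self.blocks = torch.nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            # Only the block inputs are kept; intermediate activations are recomputed.
            x = checkpoint(block, x, use_reentrant=False)
        return x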

FlashAttention

FlashAttention is a game-changer for transformer models. Standard attention has O(n²) memory complexity — FlashAttention reduces this to O(n) while also being 2-4x faster through better GPU memory access patterns.

Typical results: 5-20x less attention memory, 2-4x faster attention, and practical context lengths of 64K+ tokens.

FlashAttention 2 (must-have for transformers)
# Using FlashAttention with transformers
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # Enable FlashAttention
    device_map="auto"
)

# FlashAttention reduces memory from O(n²) to O(n)
# Also 2-4x faster for long sequences
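
If you cannot install the flash-attn package, PyTorch 2.x ships a built-in scaled_dot_product_attention that dispatches to fused, FlashAttention-style kernels when the hardware and dtypes allow it. A minimal sketch with illustrative tensor shapes:

# PyTorch 2.x built-in memory-efficient attention
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)

# Avoids materializing the full (seq_len x seq_len) score matrix on supported backends
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)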

DeepSpeed ZeRO

DeepSpeed ZeRO (Zero Redundancy Optimizer) partitions model states across GPUs, enabling training of models that don't fit on a single GPU.

Stage | Partitions | Memory Reduction | Use Case
ZeRO-1 | Optimizer states | 4x | Multi-GPU, minimal overhead
ZeRO-2 | + Gradients | 8x | Most common choice
ZeRO-3 | + Parameters | 64x+ | Very large models
DeepSpeed Config (ZeRO-2)
{
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "reduce_scatter": true
    },
    "fp16": {
        "enabled": true,
        "loss_scale_window": 100
    },
    "gradient_accumulation_steps": 4,
    "train_micro_batch_size_per_gpu": 2
}
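
If you train with the Hugging Face Trainer, you can point it at a JSON file containing the config above. The file name, paths, and hyperparameters below are illustrative, and model and train_dataset are assumed to already exist:

# Pointing the Hugging Face Trainer at a DeepSpeed config (paths/values illustrative)
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    fp16=True,
    deepspeed="ds_zero2_config.json",  # the JSON file shown above
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

# Launch with the deepspeed launcher, e.g.: deepspeed train.py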

Batch Size Optimization

Batch size directly impacts memory. Use gradient accumulation to simulate larger effective batch sizes without the memory cost.

Gradient Accumulation Formula

effective_batch = micro_batch × accumulation_steps × num_gpus

Example: micro_batch=2, accumulation=8, gpus=4 → effective_batch = 64
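
In a plain training loop, accumulation just means dividing the loss by the number of accumulation steps and calling optimizer.step() once per window. A minimal sketch:

# Gradient accumulation: the optimizer steps once every `accumulation_steps` micro-batches
accumulation_steps = 8

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, targets) / accumulation_steps  # average over the window
    loss.backward()                                          # gradients accumulate in .grad

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()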

Finding Optimal Batch Size

  1. Start with batch_size=1
  2. Double it until you hit an out-of-memory error (see the sketch after this list)
  3. Use ~80% of the maximum as a safe value
  4. Add gradient accumulation to reach your target effective batch size
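
A rough way to automate steps 2 and 3, assuming you can build a batch of any size with a small factory function of your own (make_batch below is hypothetical):

# Probe the largest batch size that fits, then back off to ~80% of it (a sketch)
import torch

def find_max_batch_size(model, make_batch, start=1, limit=1024):
    last_ok = 0
    batch_size = start
    while batch_size <= limit:
        try:
            inputs = make_batch(batch_size)           # hypothetical batch factory
            model(inputs).float().mean().backward()   # stand-in forward/backward pass
            model.zero_grad(set_to_none=True)
            last_ok = batch_size
            batch_size *= 2
        except torch.cuda.OutOfMemoryError:           # PyTorch 1.13+
            break
        finally:
            torch.cuda.empty_cache()
    return max(1, int(last_ok * 0.8))   # use ~80% of the largest working size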

Memory vs Batch Size

  • Batch 1: ~8GB
  • Batch 4: ~14GB
  • Batch 8: ~22GB
  • Batch 16: ~38GB (OOM on a 24GB card)

Practical Tips

Start with Mixed Precision

Always enable FP16/BF16 first. It's free performance and memory.

model.half() # or use autocast()

Use Gradient Accumulation

Simulate larger batches without more memory.

accumulation_steps = 8 # effective_batch = micro_batch * 8

Monitor with nvidia-smi

Watch memory usage in real-time to find bottlenecks.

watch -n 0.5 nvidia-smi

Clear Cache Regularly

Free up fragmented memory between operations.

torch.cuda.empty_cache()

Use torch.no_grad() for Inference

Disable gradient tracking when not training.

with torch.no_grad(): outputs = model(inputs)

Profile Memory Usage

Find exactly where memory is being used.

torch.cuda.memory_summary()
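
Beyond memory_summary(), two quick calls give you the headline numbers programmatically:

# Headline memory numbers without the full summary
import torch

print(f"{torch.cuda.memory_allocated() / 1024**3:.1f} GB currently allocated")
print(f"{torch.cuda.max_memory_allocated() / 1024**3:.1f} GB peak allocation")
torch.cuda.reset_peak_memory_stats()  # reset the peak counter between experiments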

Technique Comparison

Technique | Memory Saved | Speed | Complexity
Mixed Precision (FP16/BF16) | 50% | 2-3x | Low
Gradient Checkpointing | 60-70% | 0.8x (slower) | Low
FlashAttention | 5-20x (attention) | 2-4x | Medium
DeepSpeed ZeRO-2 | 8x | 1x | Medium
DeepSpeed ZeRO-3 | 64x+ | 0.9x | High
8-bit Optimizers | 75% (optimizer) | 1x | Low

Need More GPU Memory?

Access A100 80GB and H100 GPUs on Griddly Cloud. Train larger models without memory constraints — at 70% less than AWS.