What is mixed precision training and why does it matter?
Mixed precision training runs the forward and backward passes in float16 (or bfloat16), storing activations and a working copy of the weights at that precision, while keeping a float32 master copy of the weights for the update step. This roughly halves activation memory and, on GPUs with tensor cores, can deliver 2–4x higher throughput, with negligible accuracy loss when float16 is combined with loss scaling.
How to think about it
float32 uses 4 bytes per value; float16 uses 2. Activations usually dominate GPU memory during training, so halving their storage lets a substantially larger batch (up to roughly 2x) fit on the same hardware. Modern GPUs (A100, H100) also run float16 matrix multiplications on specialised tensor cores at several times their float32 throughput.
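A back-of-envelope check of the byte counts, using NumPy dtypes. The activation shape (batch 32, sequence length 1024, hidden size 768) and the helper `activation_mib` are hypothetical, chosen only for illustration:

```python
import math

import numpy as np


def activation_mib(shape, dtype):
    """Memory in MiB needed to store one activation tensor of this shape and dtype."""
    return math.prod(shape) * np.dtype(dtype).itemsize / 2**20


shape = (32, 1024, 768)  # hypothetical (batch, sequence length, hidden size)
for dtype in (np.float32, np.float16):
    print(f"{np.dtype(dtype).name}: {activation_mib(shape, dtype):.0f} MiB")
# float32: 96 MiB
# float16: 48 MiB
```

bfloat16 is also 2 bytes per value, so the same halving applies to it.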
Why keep a float32 master copy?
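Weight updates (learning rate times gradient) are tiny, often 1e-7 or smaller. float16 handles such values poorly: below about 6e-5 (its smallest normal value) numbers are stored as less precise subnormals, and below about 6e-8 they round to zero, silently dropping small gradient components. Even an update that is representable on its own can vanish when applied to a weight: float16 carries only 11 significant bits, so subtracting 1e-4 from a weight of 1.0 leaves it at exactly 1.0. Keeping the master weights and the optimizer step in float32 preserves these small updates.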
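Both failure modes are easy to reproduce with NumPy's `float16` and `float32` scalars. The update size `1e-4` and the 1000-step loop are arbitrary values chosen for illustration, standing in for repeated SGD steps with a constant `lr * grad`:

```python
import numpy as np

# Underflow: values below fp16's smallest subnormal (2**-24, about 6e-8) round to zero.
print(np.float16(1e-8))      # 0.0
print(np.float16(1e-6) > 0)  # True: kept, but only as an imprecise subnormal

# Swamping: a small update applied to a larger weight is rounded away in fp16.
update = 1e-4  # hypothetical lr * grad, identical every step
w16, w32 = np.float16(1.0), np.float32(1.0)
for _ in range(1000):
    w16 = w16 - np.float16(update)  # fp16 "weight": each step rounds back to 1.0
    w32 = w32 - np.float32(update)  # fp32 master weight: each step is kept
print(w16, w32)  # the fp16 weight never moves; the fp32 weight ends close to 0.9
```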
Loss scaling
Because small float16 gradients underflow, PyTorch's GradScaler multiplies the loss by a large scale factor before backprop, shifting gradient magnitudes into the representable range. Before the optimizer step it unscales the gradients, skips the step if any of them overflowed to inf/NaN, and then adjusts the scale factor.
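The arithmetic behind loss scaling can be sketched with NumPy on a single hypothetical gradient value of 1e-8, using a scale of 2**16 (GradScaler's default initial scale). In real training the scaling happens inside autograd; this only shows why multiplying before the float16 step and dividing in float32 afterwards preserves the value:

```python
import numpy as np

grad = 1e-8        # hypothetical true gradient value, below fp16's range
scale = 2.0 ** 16  # 65536

unscaled_fp16 = np.float16(grad)        # underflows to 0.0
scaled_fp16 = np.float16(grad * scale)  # about 6.55e-4, a normal fp16 value
recovered = np.float32(scaled_fp16) / np.float32(scale)  # unscale in fp32

print(unscaled_fp16)  # 0.0
print(recovered)      # close to 1e-8 (fp16 relative error is at most about 5e-4)
```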
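```python
from torch.cuda.amp import autocast, GradScaler

# Assumes model (on a CUDA device, fp32 parameters), optimizer, criterion and dataloader exist.
scaler = GradScaler()  # initial scale factor defaults to 2**16

for batch in dataloader:
    optimizer.zero_grad()
    with autocast():  # eligible ops (e.g. matmuls) run in fp16; parameters stay fp32
        output = model(batch["input"])
        loss = criterion(output, batch["label"])
    scaler.scale(loss).backward()  # backward on the scaled loss; ops reuse the forward dtypes
    scaler.step(optimizer)         # unscale grads, skip on inf/NaN, else update fp32 weights
    scaler.update()                # adjust the scale factor for the next iteration
```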
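GradScaler's scale factor is dynamic. The policy behind `scaler.update()` can be sketched in plain Python; `DynamicLossScaler` is a hypothetical, simplified class, not PyTorch's implementation, and it folds the inf/NaN check, unscaling, and scale adjustment into a single call on a list of floats. Its defaults mirror GradScaler's documented defaults (initial scale 2**16, growth factor 2, backoff factor 0.5, growth interval 2000):

```python
import math


class DynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler (not PyTorch's code)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive steps without inf/NaN gradients

    def step(self, scaled_grads):
        """Return unscaled gradients, or None if this optimizer step should be skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow: shrink the scale, reset the counter, skip the update.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
            return None
        unscaled = [g / self.scale for g in scaled_grads]  # unscale with the current scale
        self._clean_steps += 1
        if self._clean_steps >= self.growth_interval:
            # A long run of finite gradients: try a larger scale next time.
            self.scale *= self.growth_factor
            self._clean_steps = 0
        return unscaled


scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.step([float("inf"), 1.0]))  # None; scale drops to 4.0
print(scaler.step([4.0, 8.0]))           # [1.0, 2.0]
print(scaler.step([4.0, 8.0]))           # [1.0, 2.0]; scale grows back to 8.0
```

Backing off immediately on overflow but growing only after a long clean run keeps the scale near the largest value that does not overflow, which is what protects the smallest gradients.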
bfloat16 vs float16
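bfloat16 has the same 8-bit exponent as float32, so it covers the same range and small gradients do not underflow, but it keeps only 7 explicit mantissa bits (float16 keeps 10), so it is less precise. On TPUs and on NVIDIA GPUs from the Ampere generation (A100) onward it is often preferred because it usually removes the need for loss scaling.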
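NumPy has no built-in bfloat16, but since bfloat16 is simply the top 16 bits of a float32, the trade-off can be sketched by rounding float32 bit patterns. `to_bfloat16` is a hypothetical helper that assumes finite inputs (NaN/inf handling is omitted) and uses round-to-nearest-even:

```python
import numpy as np


def to_bfloat16(x):
    """Round float32 values to the nearest bfloat16, returned as float32 (finite inputs only)."""
    bits = np.array(x, dtype=np.float32, ndmin=1).view(np.uint32)
    # bfloat16 keeps the upper 16 bits; add a bias so truncation rounds to nearest, ties to even.
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    bits = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return bits.view(np.float32)


values = np.array([1e-8, 1.001], dtype=np.float32)
print(values.astype(np.float16))  # 1e-8 underflows to 0; 1.001 survives (as about 1.00098)
print(to_bfloat16(values))        # 1e-8 survives (approximately); 1.001 rounds to 1.0
```

The two dtypes fail in opposite places: float16 loses the tiny value (range), bfloat16 loses the small difference from 1.0 (precision).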
Practical impact
- Training a GPT-2 sized model: ~40 % reduction in GPU memory, ~2x tokens/second.
- Quality: typically within 0.1–0.3 % of float32 on standard benchmarks.