datarekha
Deep Learning · Medium · Asked at NVIDIA, Google, Meta, OpenAI

What is mixed precision training and why does it matter?

The short answer

Mixed precision training runs the forward and backward passes in float16 (or bfloat16) while keeping a float32 master copy of the weights for the update step. This roughly halves activation memory and, on GPUs with tensor cores, often delivers 2–4x higher training throughput, with negligible accuracy loss provided float16 runs use loss scaling.

How to think about it

float32 uses 4 bytes per value; float16 uses 2. Halving the storage of activations, which often dominate GPU memory during training, lets a larger batch or longer sequence fit on the same hardware. Weight-related memory does not shrink the same way, because the float32 master weights and optimizer state are still kept. The speedup comes from tensor cores: GPUs such as the A100 and H100 run float16 and bfloat16 matrix multiplications at much higher peak throughput than standard float32 arithmetic.
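
The arithmetic can be checked with a short sketch. The tensor shape is hypothetical, and the per-parameter accounting assumes an Adam-style optimizer (two float32 moment buffers) plus the float16-weights-and-float32-master scheme from the short answer; under those assumptions it shows that the saving comes from activations, not from per-parameter training state.

```python
import struct

# Bytes per element for IEEE 754 formats: "f" = float32, "e" = float16
FP32_BYTES = struct.calcsize("f")  # 4
FP16_BYTES = struct.calcsize("e")  # 2


def activation_bytes(batch, seq_len, hidden, bytes_per_elem):
    """Memory for one hypothetical (batch, seq_len, hidden) activation tensor."""
    return batch * seq_len * hidden * bytes_per_elem


def per_param_bytes(mixed: bool) -> int:
    """Per-parameter training state, assuming Adam-style float32 moments (8 bytes)."""
    if mixed:
        # fp16 weights (2) + fp16 grads (2) + fp32 master weights (4) + moments (8)
        return 2 + 2 + 4 + 8
    # fp32 weights (4) + fp32 grads (4) + moments (8)
    return 4 + 4 + 8


MiB = 1024 ** 2
shape = (32, 1024, 4096)  # hypothetical batch, sequence length, hidden size
print(activation_bytes(*shape, FP32_BYTES) / MiB)  # 512.0
print(activation_bytes(*shape, FP16_BYTES) / MiB)  # 256.0
print(per_param_bytes(mixed=False), per_param_bytes(mixed=True))  # 16 16
```

Activation memory halves, while per-parameter state stays at 16 bytes in both cases under this accounting.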

Why keep a float32 master copy?

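Weight updates (learning rate times gradient) are often tiny, around 1e-7 or smaller. float16 has only 10 mantissa bits, so adding an update smaller than about 2^-11 (≈ 5e-4) of a weight's magnitude rounds away entirely. Its range is also narrow: values below ~6e-5 (the smallest normal float16) lose precision, and values below ~6e-8 flush to zero, silently dropping small gradient components. The optimizer step is therefore done in float32 against a master copy of the weights, so these small updates accumulate instead of vanishing.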
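
A minimal sketch of both effects in plain Python, using the standard struct module's "e" (IEEE 754 half precision) and "f" (single precision) formats to round values. The update size of 1e-4 and the starting weight of 1.0 are illustrative assumptions; the loop mimics 1000 optimizer steps applied either to a float16-only weight or to a float32 master copy.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 float16 value (struct's 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE 754 float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


update = 1e-4       # hypothetical lr * grad for one step
w16 = to_fp16(1.0)  # weight kept only in float16
w32 = to_fp32(1.0)  # float32 master copy of the same weight

for _ in range(1000):
    # Each sum is exact in Python's double precision, then rounded to the
    # target format, mimicking a float16 or float32 addition.
    w16 = to_fp16(w16 + to_fp16(update))
    w32 = to_fp32(w32 + to_fp32(update))

print(w16)            # 1.0: every update was rounded away
print(round(w32, 3))  # 1.1: the float32 master copy kept them
print(to_fp16(1e-8))  # 0.0: below float16's smallest subnormal
```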

Loss scaling

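Because small float16 gradients can underflow to zero, PyTorch’s GradScaler multiplies the loss by a large scale factor before backprop (shifting gradient magnitudes into the representable range), then unscales the gradients before the optimizer step. The scale is dynamic: if any gradient overflows to inf/NaN, that step is skipped and the scale is reduced; after a run of stable steps, it is increased again.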

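```python
import torch
from torch.amp import autocast, GradScaler

# model, optimizer, criterion and dataloader are assumed to be defined,
# with the model and batches already on a CUDA device
scaler = GradScaler("cuda")  # dynamic loss scaling for float16

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass: autocast runs eligible ops (e.g. matmuls) in float16;
    # the model's parameters themselves remain float32
    with autocast(device_type="cuda", dtype=torch.float16):
        output = model(batch["input"])
        loss = criterion(output, batch["label"])

    scaler.scale(loss).backward()  # backprop the scaled loss
    scaler.step(optimizer)         # unscale grads; skip the step if any are inf/NaN
    scaler.update()                # shrink the scale after overflow, grow it periodically
```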
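
The dynamic scaling logic can be sketched in plain Python without PyTorch. The DynamicLossScaler class below is hypothetical and simplified: it uses struct's float16 format to simulate storing gradients of the scaled loss, returns None on overflow (the signal to skip the step), and adjusts the scale using GradScaler's documented defaults for init_scale, growth_factor, backoff_factor and growth_interval. It is not PyTorch's implementation.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest float16; struct raises OverflowError if out of range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


class DynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler (illustration only)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def unscale_grads(self, true_grads):
        """Simulate float16 gradients of the scaled loss, then unscale them.
        Returns None if any scaled gradient overflows float16."""
        try:
            scaled = [to_fp16(g * self.scale) for g in true_grads]
        except OverflowError:
            return None  # caller should skip this optimizer step
        return [g / self.scale for g in scaled]  # unscale in higher precision

    def update(self, found_overflow: bool):
        """Back off after an overflow; grow after growth_interval clean steps."""
        if found_overflow:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


true_grads = [1e-8, 3e-3]  # hypothetical gradient values
print(to_fp16(1e-8))       # 0.0: unscaled, the small gradient underflows

scaler = DynamicLossScaler()
grads = scaler.unscale_grads(true_grads)
print(grads is not None)   # True: scaled by 2**16, 1e-8 survives
scaler.update(found_overflow=grads is None)

overflow = scaler.unscale_grads([1.0])  # 1.0 * 2**16 = 65536 > 65504 (float16 max)
print(overflow)                         # None: skip this optimizer step
scaler.update(found_overflow=overflow is None)
print(scaler.scale)                     # 32768.0
```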

bfloat16 vs float16

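bfloat16 keeps float32's 8 exponent bits, so it has the same dynamic range and gradients rarely underflow, but it has only 7 mantissa bits versus float16's 10, so it is less precise. On TPUs and on NVIDIA GPUs from the Ampere generation onward (e.g. A100, H100), it is often preferred because it usually makes loss scaling unnecessary.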
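
The trade-off shows up when the same values are rounded to each format. This sketch uses struct's "e" format for float16 and, as an assumption for illustration, emulates bfloat16 by keeping the upper 16 bits of a float32 bit pattern with round-to-nearest-even (NaN handling is omitted).

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest float16 (5 exponent bits, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Emulate bfloat16 (8 exponent bits, 7 mantissa bits) from float32 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)      # round half to even
    bits = ((bits >> 16) << 16) & 0xFFFFFFFF  # drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Range: a tiny gradient survives in bfloat16 but not in float16
print(to_fp16(1e-8), to_bf16(1e-8) != 0.0)  # 0.0 True

# Precision: the step size near 1.0 is 2**-7 in bfloat16, 2**-10 in float16
print(to_bf16(1.01))  # 1.0078125
print(to_fp16(1.01))  # 1.009765625

# Large values: float16 tops out at 65504, bfloat16 does not
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000 overflows float16")
print(to_bf16(70000.0))  # 70144.0: in range, but coarsely rounded
```

bfloat16 keeps both extremes in range at the cost of coarser rounding near 1.0, which is why it usually needs no loss scaling.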

Practical impact

  • Training a GPT-2 sized model: roughly 40% less GPU memory and about 2x tokens/second (exact gains depend on hardware, model, and batch size).
  • Quality: typically within 0.1–0.3% of float32 on standard benchmarks.