Multi-GPU: DDP & FSDP
When the model fits on one GPU, replicate it (DDP). When it doesn't, shard it (FSDP). The memory math behind the decision, and how data parallelism actually scales.
What you'll learn
- The ~16-bytes-per-parameter memory rule for mixed-precision Adam
- Data parallelism (DDP) — replicate the model, split the batch, all-reduce the grads
- When you must shard with FSDP, and the DDP-vs-FSDP decision rule
Before you start
At some point one GPU isn’t enough — either training is too slow, or the model simply doesn’t fit in memory. Multi-GPU training is now a baseline skill, and the whole subject comes down to one question you can answer with arithmetic: does the model’s training state fit on a single GPU? The answer picks your strategy.
First, the memory math
People underestimate training memory because they only count the weights. Training a model with the Adam optimizer in mixed precision costs roughly 16 bytes per parameter:
bf16 weights          2 bytes
bf16 gradients        2 bytes
fp32 master weights   4 bytes   ← the "master copy" for stable updates
Adam moment m         4 bytes
Adam moment v         4 bytes
                     --------
total                16 bytes / parameter
So a 7B-parameter model needs about 7e9 × 16 = 112 GB just for model state —
already over an 80GB GPU, before activations. That single number drives the whole
decision.
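The arithmetic above can be checked in a few lines of Python. This is a minimal sketch: the `model_state_gb` helper and the `BYTES_PER_PARAM` table are illustrative names rather than part of any library, and 1 GB is taken as 1e9 bytes to match the 112 GB figure.

```python
# Bytes per parameter for mixed-precision Adam, copied from the breakdown above.
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "Adam moment m": 4,
    "Adam moment v": 4,
}


def model_state_gb(num_params: float) -> float:
    """Model-state memory in GB (1 GB = 1e9 bytes), excluding activations."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


# Print the model-state footprint for a few common model sizes.
for name, num_params in [("1B", 1e9), ("7B", 7e9), ("70B", 70e9)]:
    print(f"{name}: {model_state_gb(num_params):.0f} GB of model state")
```

Activations, temporary buffers, and framework overhead come on top of this number, so treat it as a lower bound.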
DDP — replicate the model, split the batch
Distributed Data Parallel is the workhorse when the model does fit. Every GPU holds a full copy of the model; you split each batch across the GPUs, each computes gradients on its shard, and then an all-reduce averages the gradients so every replica stays in sync before the optimizer step.
DDP gives near-linear speedup but no memory savings — each GPU still holds the entire model state. It’s the right tool when the model fits and you just want to train faster on a bigger effective batch.
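# Launch with: torchrun --nproc_per_node=4 train.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun for each process
torch.cuda.set_device(local_rank)           # bind this process to its own GPU
# `model` is your nn.Module, constructed before this point
model = DDP(model.to(local_rank), device_ids=[local_rank])
# ... the normal training loop; DDP all-reduces gradients during backward()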
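To see why averaging per-GPU gradients keeps the replicas in sync, here is a small pure-Python simulation of the all-reduce step. It is a sketch under simplifying assumptions: a one-parameter linear model, equal-sized shards, and a hypothetical `all_reduce_mean` function standing in for the real collective. It does not use PyTorch or any GPU.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for an averaging all-reduce: every rank receives the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


w = 0.5                                   # every replica holds the same weight
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]                # targets from the "true" weight 2.0
world_size = 4                            # number of simulated GPUs
shard = len(xs) // world_size             # equal-sized slice of the batch per GPU

# Each simulated GPU computes a gradient on its own slice of the batch.
local_grads = [
    grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]

synced_grads = all_reduce_mean(local_grads)   # what every rank sees after sync
full_batch_grad = grad_mse(w, xs, ys)         # single-GPU reference

print("local:", local_grads)
print("synced:", synced_grads[0], "full batch:", full_batch_grad)
```

With equal-sized shards, the averaged gradient matches the gradient a single GPU would compute on the whole batch, so every replica applies the same update.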
FSDP — shard the model when it won’t fit
When the full state exceeds one GPU, Fully Sharded Data Parallel splits the parameters, gradients, and optimizer states across the GPUs. Each GPU stores only its 1/N shard and gathers the full weights for a layer only momentarily, during that layer’s forward/backward, then frees them. Per-GPU memory for model state drops roughly linearly with the number of GPUs, at the cost of extra communication to gather the weights. That trade is exactly what lets you train a 70B model that could never fit on one card.
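The sharding arithmetic and the resulting DDP-vs-FSDP choice can be sketched as follows. The helper names and the `headroom_fraction` default (the share of GPU memory reserved for activations and overhead) are assumptions made for illustration, not values from any library. The 16 bytes per parameter come from the memory breakdown earlier.

```python
BYTES_PER_PARAM = 16  # mixed-precision Adam model state, per the earlier breakdown


def state_gb_per_gpu(num_params, num_gpus, sharded):
    """Model-state GB held by each GPU (1 GB = 1e9 bytes)."""
    total_gb = num_params * BYTES_PER_PARAM / 1e9
    # DDP replicates the full state; FSDP stores roughly a 1/N shard per GPU.
    return total_gb / num_gpus if sharded else total_gb


def pick_strategy(num_params, gpu_memory_gb, headroom_fraction=0.3):
    """Return "DDP" if the full state fits on one GPU with headroom, else "FSDP".

    headroom_fraction is an assumed reserve for activations and overhead.
    """
    budget_gb = gpu_memory_gb * (1 - headroom_fraction)
    full_state_gb = state_gb_per_gpu(num_params, 1, sharded=False)
    return "DDP" if full_state_gb <= budget_gb else "FSDP"


for name, num_params in [("1B", 1e9), ("7B", 7e9), ("70B", 70e9)]:
    print(
        f"{name}: {pick_strategy(num_params, gpu_memory_gb=80)}, "
        f"{state_gb_per_gpu(num_params, 16, sharded=True):.1f} GB/GPU "
        f"if sharded over 16 GPUs"
    )
```

The right headroom depends on batch size, sequence length, and activation checkpointing, so the threshold is something to tune rather than a fixed rule.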
Next
That completes the training half of the track. Next we move to the architecture that everything modern is built on — starting with how text becomes tensors in tokenization, then self-attention and the transformer block.
Practice this in an interview
Distributed Data Parallel (DDP) replicates the full model on every GPU and synchronizes gradients each step, which is simple but requires the whole model to fit on one GPU. Fully Sharded Data Parallel (FSDP) shards parameters, gradients, and optimizer states across GPUs and gathers them on demand, drastically cutting per-GPU memory so you can train much larger models at the cost of extra communication.
GPUs execute tensor operations efficiently only when the batch dimension is large enough to saturate all CUDA cores. Dynamic batching collects individual requests arriving within a short window and fuses them into a single GPU call, dramatically improving throughput and cost efficiency without sacrificing per-request latency beyond the configured wait threshold.
Neural network training is dominated by large matrix multiplications that are embarrassingly parallel. GPUs have thousands of small cores optimised for this exact operation, whereas CPUs have tens of powerful cores optimised for low-latency sequential logic. The throughput difference is 10–100x for typical DL workloads.
Mixed precision training stores weights and activations in float16 (or bfloat16) for forward/backward passes while keeping a float32 master copy of weights for the update step. This roughly halves activation memory and delivers 2–4x throughput on modern tensor cores, with negligible accuracy loss. float16 typically needs loss scaling to keep small gradients from underflowing, while bfloat16 usually does not.