Batch size ↔ learning rate
Batch size and learning rate are a coupled pair, not two independent knobs. The linear scaling rule, warmup, and gradient accumulation — how to train big-batch on a small GPU.
What you'll learn
- Why batch size and learning rate must be tuned together (the linear scaling rule)
- Why large-batch training needs learning-rate warmup
- How gradient accumulation simulates a big batch on a small GPU
Before you start
A beginner tunes batch size and learning rate as if they were independent. They are not. Change one and the other’s best value shifts with it — and missing that coupling is why “I increased the batch size and training got worse” is such a common complaint. The reason is simple: the batch size sets how noisy each gradient is, and the learning rate sets how far you step on that gradient.
Bigger batch = less noisy gradient
Each batch estimates the true gradient from a sample. A small batch is a noisy
estimate; a large batch averages more examples, so its gradient is smoother: the
standard deviation of the noise shrinks roughly as 1/√(batch size). As the batch
size grows, the descent path goes from jagged to smooth, the number of steps per
epoch falls, and the scaled learning rate rises with it.
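To make the 1/√(batch size) claim concrete, here is a minimal simulation sketch. It assumes each per-example "gradient" is a scalar drawn from a normal distribution with mean 1 and standard deviation 1, which is a toy stand-in rather than a real model; the function name `batch_gradient_noise` is made up for this illustration.

```python
import random
import statistics

def batch_gradient_noise(batch_size, n_trials=2000, seed=0):
    """Std. dev. of the mini-batch mean when per-example 'gradients' are N(1, 1).

    Toy assumption: a real gradient is a vector with a data-dependent
    distribution; a scalar Gaussian is used here only to show the trend.
    """
    rng = random.Random(seed)
    batch_means = [
        statistics.fmean(rng.gauss(1.0, 1.0) for _ in range(batch_size))
        for _ in range(n_trials)
    ]
    return statistics.stdev(batch_means)

for b in (1, 4, 16, 64):
    measured = batch_gradient_noise(b)
    predicted = b ** -0.5  # 1 / sqrt(batch size)
    print(f"batch={b:3d}  measured noise={measured:.3f}  1/sqrt(batch)={predicted:.3f}")
```

The measured column should track the predicted one closely: quadrupling the batch roughly halves the noise, so each further reduction costs progressively more examples.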
A smoother gradient is more trustworthy, so you can afford a bigger step. That’s the intuition behind the rule everyone uses.
The linear scaling rule
When you multiply the batch size by k, multiply the learning rate by k.
Double the batch, double the learning rate. It’s an approximation, but it’s the standard starting point — it keeps the total distance traveled per epoch roughly constant as you scale up. (Some recipes use a square-root rule instead; linear is the common default for SGD-style training.)
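As a small sketch of the rule, the helper below rescales a learning rate that was tuned at one batch size for use at another. The name `scale_lr`, the base learning rate of 0.05, and the dataset size of 51,200 are illustrative assumptions, not values from any particular recipe.

```python
import math

def scale_lr(base_lr, base_batch, new_batch, rule="linear"):
    """Rescale a learning rate tuned at base_batch for use at new_batch."""
    k = new_batch / base_batch
    if rule == "linear":
        return base_lr * k             # linear scaling rule
    if rule == "sqrt":
        return base_lr * math.sqrt(k)  # square-root alternative
    raise ValueError(f"unknown rule: {rule!r}")

dataset_size = 51_200  # hypothetical dataset size
for batch in (64, 128, 256, 512):
    lr = scale_lr(0.05, base_batch=64, new_batch=batch)
    steps_per_epoch = dataset_size // batch
    # lr * steps_per_epoch stays (approximately) constant under the linear rule
    print(f"batch={batch:4d}  lr={lr:.3f}  steps/epoch={steps_per_epoch:4d}  "
          f"lr*steps={lr * steps_per_epoch:.1f}")
```

The last column is the "distance per epoch" argument in numbers: a batch that is k times larger takes k times fewer steps, and the larger learning rate compensates.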
Gradient accumulation: big batch on a small GPU
Want a batch of 512 but only 64 fit in memory? Accumulate gradients over
several forward/backward passes before stepping. Because backprop adds to
.grad (the same reason you call zero_grad()), running backward() 8 times
without stepping sums 8 mini-batches’ gradients — mathematically a batch of
8 × 64 = 512.
In PyTorch, the pattern is: don’t zero_grad() until after the accumulation, and
scale the loss so the sum behaves like an average:
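accum_steps = 8  # 8 micro-batches of 64 -> effective batch of 512

# model, loss_fn, opt and loader are assumed to be defined elsewhere
opt.zero_grad()
for i, (xb, yb) in enumerate(loader):
    loss = loss_fn(model(xb), yb) / accum_steps  # scale so grads average
    loss.backward()                              # ADDS to .grad
    if (i + 1) % accum_steps == 0:
        opt.step()                               # step once per big batch
        opt.zero_grad()                          # reset for the next big batch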
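The equivalence can be checked by hand without a GPU. The sketch below uses plain Python and a toy 1-D model `y_hat = w * x` with a mean-squared-error loss (both are assumptions chosen for brevity, not part of the PyTorch loop above), and compares the full-batch gradient against 8 accumulated micro-batch gradients.

```python
import math
import random

def mse_grad(w, xs, ys):
    """d/dw of the mean squared error for the 1-D model y_hat = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

rng = random.Random(0)
xs = [rng.uniform(-1.0, 1.0) for _ in range(512)]
ys = [3.0 * x + rng.gauss(0.0, 0.1) for x in xs]  # synthetic data
w = 0.5

full_batch_grad = mse_grad(w, xs, ys)  # one pass over all 512 examples

micro_batch, accum_steps = 64, 8
accumulated = 0.0  # plays the role of .grad
for i in range(accum_steps):
    lo, hi = i * micro_batch, (i + 1) * micro_batch
    # Dividing by accum_steps mirrors `loss / accum_steps` in the training loop.
    accumulated += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps

print(f"full batch: {full_batch_grad:.6f}   accumulated: {accumulated:.6f}")
assert math.isclose(full_batch_grad, accumulated, rel_tol=1e-9)
```

The match relies on equal-sized micro-batches and a loss that averages over examples. Layers that compute statistics over the batch, such as batch normalization, still see only 64 examples at a time, so accumulation is not identical to a true batch of 512 for them.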
Next
You can now train efficiently on the hardware you have. To go beyond one GPU, distributed training (DDP & FSDP) shows how the batch — and the model itself — gets split across devices.
Practice this in an interview
Larger batches give lower-variance gradient estimates, so they typically allow and often need a proportionally larger learning rate, while very high learning rates early in training can destabilize it. Warmup ramps the learning rate up from a small value over the first steps to avoid early divergence, then follows a decay schedule.
Larger batches give more accurate gradient estimates and enable higher GPU utilisation, but they tend to converge to sharper minima that generalise worse. Smaller batches introduce gradient noise that acts as implicit regularisation, helping the optimiser escape sharp minima and often finding flatter, better-generalising solutions — at the cost of slower wall-clock training per epoch.
A learning rate schedule changes the learning rate during training rather than keeping it fixed. Warmup starts with a very small LR and ramps it up over the first few hundred or thousand steps, preventing early large gradient updates from destabilising freshly initialised weights. After warmup, the LR is typically decayed (via cosine annealing, step decay, or linear decay) so the optimiser can settle into a minimum instead of bouncing around it.
Gradient accumulation runs several forward and backward passes without zeroing gradients, sums them, and only steps the optimizer after N micro-batches, simulating a larger effective batch size than fits in memory. It lets you train with large effective batches on limited GPU memory at the cost of more wall-clock time per optimizer step, because the micro-batches run one after another.
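The warmup-then-decay schedule described in the answers above can be sketched as a single function of the step number. This is one common shape (linear warmup followed by cosine decay), not the only one; the name `lr_at_step` and the values 0.4, 500, and 10,000 are assumptions picked for illustration.

```python
import math

def lr_at_step(step, peak_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        # Ramp from peak_lr / warmup_steps up to peak_lr.
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)  # clamp if training runs past total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for step in (0, 249, 499, 500, 5_250, 9_999):
    lr = lr_at_step(step, peak_lr=0.4, warmup_steps=500, total_steps=10_000)
    print(f"step={step:5d}  lr={lr:.5f}")
```

Here the peak is the scaled learning rate for the large batch, and warmup only controls how quickly training reaches it. In PyTorch the same shape is usually expressed with the schedulers in `torch.optim.lr_scheduler` rather than a hand-written function.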