What is a learning rate schedule, and why is warmup important?
A learning rate (LR) schedule changes the learning rate during training instead of keeping it fixed. Warmup starts with a very small LR and ramps it up over the first few hundred or thousand steps, so that large early updates do not destabilise the freshly initialised weights. After warmup, the LR is typically decayed (via cosine annealing, step decay, or linear decay) so the optimiser can settle into a good minimum instead of bouncing around it.
How to think about it
A fixed learning rate is almost never optimal. A rate that is too high early on can make training diverge, and a rate that stays high late in training keeps the optimiser bouncing around a minimum instead of converging to it. Schedules address both problems.
Why warmup?
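At the start of training, the model's weights are random and Adam's moment estimates start at zero. Bias correction removes the zero-initialisation bias, but for the first few steps the second-moment estimate is built from only a handful of gradients, so it is noisy and the per-parameter step size is poorly calibrated. On step 1, for instance, the bias-corrected update is about lr × sign(g), so every parameter moves by roughly the full learning rate no matter how small its gradient is. A linear warmup ramps the LR from near zero to the target value over W steps, letting the moment estimates stabilise before full-size updates are taken.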
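To make the step-1 behaviour concrete, here is a minimal scalar re-implementation of Adam's bias-corrected update. It is a sketch, not PyTorch's implementation; `adam_updates` and its optional linear-warmup factor are hypothetical helpers. It returns the size of each step (the parameter moves by minus that amount).

```python
import math

def adam_updates(grads, lr, beta1=0.9, beta2=0.999, eps=1e-8, warmup_steps=0):
    """Bias-corrected Adam step sizes for a single scalar parameter.

    If warmup_steps > 0, the LR on step t (1-based) is scaled by
    min(1, t / warmup_steps), i.e. a linear warmup.
    """
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g          # first-moment EMA, starts at 0
        v = beta2 * v + (1 - beta2) * g * g      # second-moment EMA, starts at 0
        m_hat = m / (1 - beta1 ** t)             # bias correction
        v_hat = v / (1 - beta2 ** t)
        scale = min(1.0, t / warmup_steps) if warmup_steps > 0 else 1.0
        updates.append(lr * scale * m_hat / (math.sqrt(v_hat) + eps))
    return updates

# Step 1: the step is close to lr for a tiny and a huge gradient alike.
print(adam_updates([1e-4], lr=1e-3)[0])
print(adam_updates([100.0], lr=1e-3)[0])
# With a 500-step linear warmup, the first step is close to lr / 500.
print(adam_updates([1.0], lr=1e-3, warmup_steps=500)[0])
```

Because Adam normalises each gradient by its own running magnitude, the early step size is set almost entirely by the LR, which is why it is the LR that warmup keeps small.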
For transformers, warmup is standard practice, and for the original post-LayerNorm architecture it is close to mandatory: without it, the large early updates to the embedding and attention layers often make training diverge.
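As one concrete transformer example, the original Transformer paper ("Attention Is All You Need") used linear warmup followed by inverse-square-root decay: lr = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5), with warmup_steps = 4000 and d_model = 512 for the base model. A minimal sketch of that formula; the function name `noam_lr` is illustrative, not a library API.

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    """Original Transformer LR: linear warmup, then inverse-sqrt decay."""
    step = max(step, 1)  # the formula is undefined at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in (1, 1000, 4000, 16000):
    print(s, noam_lr(s))
```

The LR rises linearly until `step == warmup_steps`, where the two terms inside `min` are equal, and then decays as 1/sqrt(step).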
Common schedules after warmup
Cosine annealing decays the LR smoothly from the peak to a floor (often about 10% of the peak) along a half-cosine curve. It is a common default in modern vision and language model recipes.
Step decay multiplies the LR by a constant factor at fixed epoch milestones. ResNet training commonly divides it by 10 at each milestone (e.g. every 30 epochs on ImageNet).
Linear decay lowers the LR in a straight line to zero (or a small floor) by the end of training. It is simple and works well for fine-tuning large pretrained models.
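The three decay rules above can be written as plain functions of a step (or epoch) counter that starts at 0 when warmup ends. This is a minimal, framework-free sketch: the names `cosine_lr`, `step_lr`, `linear_lr` and their parameters are illustrative, not a library API, and the default `gamma=0.1` assumes the divide-by-10 convention.

```python
import bisect
import math

def cosine_lr(t, total, peak, floor):
    """Half-cosine from peak at t=0 down to floor at t=total."""
    t = min(t, total)
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * t / total))

def step_lr(epoch, base, milestones, gamma=0.1):
    """Multiply base by gamma once for every milestone already reached."""
    return base * gamma ** bisect.bisect_right(milestones, epoch)

def linear_lr(t, total, peak, floor=0.0):
    """Straight line from peak at t=0 down to floor at t=total."""
    frac = min(t, total) / total
    return peak + (floor - peak) * frac

print(cosine_lr(50, total=100, peak=1e-3, floor=1e-4))
print(step_lr(45, base=0.1, milestones=[30, 60]))
print(linear_lr(25, total=100, peak=1e-3))
```

Each function clamps `t` at `total`, so calling it past the end of the schedule simply returns the floor.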
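In PyTorch, linear warmup followed by cosine decay can be composed with SequentialLR:
```python
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

model = torch.nn.Linear(512, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # peak LR
warmup_steps = 500
total_steps = 10_000

# Warmup: scale the LR linearly from 1e-3 * peak up to the full peak.
warmup = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_steps)
# Decay: cosine from the peak down to eta_min (1% of peak here) over the remaining steps.
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=1e-5)
# Switch from warmup to cosine after warmup_steps scheduler steps.
schedule = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

for step in range(total_steps):
    x = torch.randn(32, 512)          # dummy batch; replace with real data
    y = torch.randint(0, 10, (32,))
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    schedule.step()                   # advance the schedule once per optimiser step
```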