Vanishing & exploding gradients
Backprop multiplies a long chain of numbers. Get them slightly below 1 and gradients vanish; slightly above and they explode. How to diagnose it with gradient norms and fix it with clipping, init, and normalization.
What you'll learn
- Why depth turns the chain rule into an exponential — vanishing or exploding gradients
- How to diagnose it by watching gradient norms per layer
- The fixes: gradient clipping, good init, ReLU, residuals, and normalization
Backprop computes a node’s gradient by
multiplying the downstream gradient by a local derivative — over and over, once
per layer, all the way back to the input. That repeated multiplication is the
whole danger. Multiply twenty numbers that are each 0.5 and you get
0.5²⁰ ≈ 0.000001. Multiply twenty numbers that are each 1.5 and you get
1.5²⁰ ≈ 3325. The chain rule turns depth into an exponential, and the
gradient either vanishes toward zero or explodes until it overflows to inf or NaN.
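The arithmetic above is easy to reproduce. The snippet below is a minimal sketch in plain Python; `chain_product` is a hypothetical helper name, and treating every layer as the same constant factor is a simplifying assumption.

```python
def chain_product(factor, depth):
    """Multiply `factor` by itself `depth` times, as backprop would
    if every layer contributed the same local derivative."""
    grad = 1.0
    for _ in range(depth):
        grad *= factor
    return grad

# Same depth, three slightly different per-layer factors.
for factor in (0.5, 1.0, 1.5):
    print(f"factor={factor}: gradient after 20 layers = {chain_product(factor, 20):.6g}")
```

Only a factor of exactly 1.0 leaves the gradient unchanged; anything else compounds with depth.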
See it happen, layer by layer
Consider a real backward pass through 15 layers, tracking the gradient norm at
each one. Gradients enter at the output and flow back toward the input. With
sigmoid, whose derivative never exceeds 0.25, the early layers are starved
of gradient, so they barely learn. Push the weight scale up and the gradient
explodes instead.
The key reading: in the vanishing case, the input-side layers are where gradients are weakest. That is why, before modern tricks, deep networks “couldn’t train their early layers”: the learning signal evaporated before it reached them.
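A per-layer measurement of this kind can be sketched without any framework. The example below is a toy under stated assumptions, not a real training setup: the network is one unit wide, every layer shares the same weight, and `layer_gradients` is a hypothetical helper. The exploding case uses ReLU with a weight of 1.5, an illustrative choice that makes the per-layer factor exactly 1.5.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def layer_gradients(weight, activation, depth=15, x=1.0):
    """Return |d output / d h_k| for k = 0..depth in a width-1 toy network.

    Index 0 is the input side, index `depth` is the output side.
    """
    # Forward pass: record each layer's local derivative d h_{k+1} / d h_k.
    local = []
    h = x
    for _ in range(depth):
        z = weight * h
        if activation == "sigmoid":
            h = sigmoid(z)
            dact = h * (1.0 - h)          # sigmoid'(z), at most 0.25
        elif activation == "relu":
            h = max(0.0, z)
            dact = 1.0 if z > 0 else 0.0  # relu'(z)
        else:
            raise ValueError(f"unknown activation: {activation}")
        local.append(weight * dact)

    # Backward pass: start with gradient 1 at the output, multiply back.
    grads = [1.0]
    for d in reversed(local):
        grads.append(grads[-1] * d)
    grads.reverse()
    return [abs(g) for g in grads]

vanishing = layer_gradients(weight=1.0, activation="sigmoid")
exploding = layer_gradients(weight=1.5, activation="relu")
for k, (v, e) in enumerate(zip(vanishing, exploding)):
    print(f"layer {k:2d}  sigmoid: {v:.3e}   relu, w=1.5: {e:.3e}")
```

Reading the printout from the bottom (output) to the top (input) shows the sigmoid column shrinking by at least a factor of four per layer and the ReLU column growing by 1.5 per layer.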
Vanishing: why it happens, how to fix it
The per-layer multiplier is roughly ‖W‖ × (activation derivative). Two things
make it shrink below 1:
- Saturating activations. Sigmoid and tanh flatten out for large inputs; their derivative goes to ~0 there. ReLU’s derivative is exactly 1 for positive inputs, so it doesn’t shrink the gradient. This is the single biggest reason ReLU replaced sigmoid in deep nets.
- Too-small weights. Covered in weight init: the wrong scale decays the signal (and its gradient) every layer.
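The saturation effect can be checked numerically. This is a small sketch using only the standard library; the function names are hypothetical and the sample inputs are arbitrary.

```python
import math

def sigmoid_grad(z):
    s = 1.0 / (1.0 + math.exp(-z))
    return s * (1.0 - s)            # peaks at 0.25 when z = 0

def tanh_grad(z):
    return 1.0 - math.tanh(z) ** 2  # peaks at 1.0 when z = 0

def relu_grad(z):
    return 1.0 if z > 0 else 0.0    # exactly 1 for any positive input

for z in (0.0, 2.0, 5.0, 10.0):
    print(f"z={z:5.1f}  sigmoid'={sigmoid_grad(z):.6f}  "
          f"tanh'={tanh_grad(z):.6f}  relu'={relu_grad(z):.1f}")
```

Note that ReLU's derivative is 0 for negative inputs, so it passes the gradient through unchanged only on the units that are active.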
The modern toolkit that keeps gradients alive through hundreds of layers:
- ReLU-family activations (ReLU, GELU, SiLU) — non-saturating.
- Good initialization — He/Xavier so the chain starts near 1.
- Residual connections: x + f(x) gives the gradient a direct path that skips the multiplications (the reason ResNets and transformers go deep).
- Normalization (BatchNorm, LayerNorm, RMSNorm): rescales activations each layer so the chain can’t drift far from 1.
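To see why the residual path helps, note that the derivative of x + f(x) with respect to x is 1 + f'(x), so each layer's factor sits near 1 even when f'(x) is tiny. The sketch below compares the two products for a scalar chain; the helper names and the ±0.1 local derivatives are made-up illustrative values, not measurements from a real network.

```python
def backprop_plain(local_derivs):
    """Gradient through y = f(x) layers: the product of f'(x)."""
    grad = 1.0
    for d in reversed(local_derivs):
        grad *= d
    return grad

def backprop_residual(local_derivs):
    """Gradient through y = x + f(x) layers: the product of 1 + f'(x)."""
    grad = 1.0
    for d in reversed(local_derivs):
        grad *= 1.0 + d
    return grad

# Hypothetical small local derivatives with alternating sign, 15 layers.
derivs = [0.1 if k % 2 == 0 else -0.1 for k in range(15)]
print("plain   :", abs(backprop_plain(derivs)))
print("residual:", backprop_residual(derivs))
```

With these values the plain chain collapses toward zero while the residual chain stays close to 1, which is the behavior the identity path is meant to provide.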
Exploding: clip the gradient
Exploding gradients are the opposite — common in RNNs and with too-large weights or learning rates. The standard fix is gradient clipping: before the optimizer step, if the total gradient norm exceeds a threshold, rescale the whole gradient down to that threshold. Direction is preserved; only the magnitude is capped.
In PyTorch this is one line, placed after backward() and before step():
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
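What clipping by global norm computes can be sketched in plain Python. This is an illustration of the idea on nested lists, not PyTorch's implementation; `clip_by_global_norm` is a hypothetical helper and the small `eps` in the denominator is an assumption added to avoid division by zero.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a list of gradient vectors so their combined L2 norm
    is at most `max_norm`. Returns (clipped_grads, original_norm)."""
    # One norm over every element of every parameter's gradient.
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Scale down only when the norm exceeds the threshold.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm

grads = [[3.0, 4.0], [0.0]]          # global norm is 5.0
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
print("norm before:", norm_before)
print("clipped    :", clipped)
```

Every component is multiplied by the same scale, which is why the direction survives. In PyTorch, `clip_grad_norm_` also returns the total norm it measured before clipping, so logging that return value is a cheap way to monitor for exploding gradients.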
Next
You can now keep a deep net’s gradients alive and bounded. Next, the knobs that set the step size those gradients drive: optimizers, learning-rate schedules, and how batch size and learning rate interact.
Practice this in an interview
Exploding gradients happen when the product of layer Jacobians has spectral norm greater than 1, causing gradients to grow exponentially with depth. Gradient clipping rescales the gradient norm to a maximum threshold before the weight update, preventing divergence without discarding gradient direction.
Vanishing gradients occur when repeated multiplication of small derivatives during backpropagation drives gradients toward zero, starving early layers of learning signal. The main fixes are better activations (ReLU/GELU), residual connections, batch normalization, and careful weight initialization.
Gradient clipping caps the magnitude of gradients (by value or by global norm) before the optimizer step, preventing exploding gradients that cause unstable or diverging training. It is especially useful in RNNs and transformers, where a single large update can destabilize learning.
Vanishing gradients occur when gradients shrink toward zero as they propagate back through many layers, so early layers learn extremely slowly or not at all; it is common with sigmoid or tanh activations in deep networks. Mitigations include ReLU-family activations, residual/skip connections, batch or layer normalization, careful initialization, and gated architectures like LSTMs.