What is the vanishing gradient problem and how do you fix it?
The short answer
Vanishing gradients occur when repeated multiplication of small derivatives during backpropagation drives gradients toward zero, starving early layers of learning signal. The main fixes are better activations (ReLU/GELU), residual connections, batch normalization, and careful weight initialization.
How to think about it
During backpropagation, the gradient of the loss w.r.t. layer l’s weights involves a product of Jacobians from every layer above it:
dL/dW^[l] ∝ (dL/da^[L]) · J^[L] · J^[L-1] · ... · J^[l+1]
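If each Jacobian has spectral norm less than 1, this product shrinks exponentially with depth. Sigmoid makes that likely: its derivative is at most 0.25, so each layer scales the gradient by at most 0.25 times the norm of its weight matrix. With 20 such layers, 0.25^20 ≈ 10^{-12}, so a 20-layer sigmoid net can have gradients of order 10^{-12} or smaller at the first layer.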
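To see where the shrinkage comes from, here is a short bound. It assumes each layer has the form a^[k] = σ(W^[k] a^[k-1] + b^[k]) with sigmoid σ and pre-activation z^[k], which the text implies but does not state, and it takes J^[k] to be the Jacobian of a^[k] with respect to a^[k-1].

```latex
\begin{align*}
J^{[k]} &= \operatorname{diag}\!\big(\sigma'(z^{[k]})\big)\, W^{[k]} \\
\|J^{[k]}\|_2 &\le \tfrac{1}{4}\,\|W^{[k]}\|_2
  \qquad \text{since } 0 < \sigma'(z) \le \tfrac{1}{4} \\
\big\| J^{[L]} J^{[L-1]} \cdots J^{[l+1]} \big\|_2
  &\le \prod_{k=l+1}^{L} \tfrac{1}{4}\,\|W^{[k]}\|_2
\end{align*}
```

So as long as the weight norms stay below 4, the bound decays geometrically in L - l, the number of layers between layer l and the output.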
Concrete example with sigmoid:
import torch
z = torch.tensor(3.0)
sig = torch.sigmoid(z)
grad = sig * (1 - sig) # ≈ 0.045 at z=3
# over 20 layers: 0.045^20 ≈ 1.2e-27
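The scalar estimate above ignores the weight matrices, so it can help to measure the effect on an actual network. The following is a minimal sketch, not a benchmark: the helper names `build_mlp` and `first_layer_grad_norm`, the width, the batch size, and the squared-output loss are all assumptions chosen for illustration. Both networks use the same He-style initialization so that the activation function is the only difference.

```python
import torch
import torch.nn as nn


def build_mlp(depth, width, activation):
    """Stack `depth` Linear+activation pairs (hypothetical helper)."""
    layers = []
    for _ in range(depth):
        linear = nn.Linear(width, width)
        # Same initialization for every run, so only the activation differs.
        nn.init.kaiming_normal_(linear.weight, nonlinearity="relu")
        nn.init.zeros_(linear.bias)
        layers += [linear, activation()]
    return nn.Sequential(*layers)


def first_layer_grad_norm(activation, depth=20, width=64, seed=0):
    """Return the gradient norm of the first layer's weights after one backward pass."""
    torch.manual_seed(seed)
    model = build_mlp(depth, width, activation)
    x = torch.randn(32, width)          # random batch, illustrative only
    loss = model(x).pow(2).mean()       # arbitrary scalar loss
    loss.backward()
    return model[0].weight.grad.norm().item()


sigmoid_norm = first_layer_grad_norm(nn.Sigmoid)
relu_norm = first_layer_grad_norm(nn.ReLU)
print(f"first-layer grad norm  sigmoid: {sigmoid_norm:.3e}  relu: {relu_norm:.3e}")
```

Under these assumptions the sigmoid network's first-layer gradient should come out many orders of magnitude smaller than the ReLU network's. The exact numbers depend on the seed, width, and loss.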
Fixes, roughly in order of impact:
- Switch to ReLU / GELU: ReLU’s derivative is exactly 1 for positive inputs (GELU’s approaches 1), so the gradient does not shrink on the active side.
- Residual connections (He et al. 2016): add a skip so the block outputs F(x) + x instead of F(x). The gradient flows straight through the addition with derivative 1, bypassing the multiplicative chain.
- Batch / Layer normalization: keeps pre-activation magnitudes in a healthy range, preventing saturation.
- He / Xavier initialization: scales the initial weights so that variance is roughly preserved from layer to layer, so the chain product starts near 1 rather than already small.
- Gradient clipping: addresses the exploding variant rather than vanishing itself, and stabilizes training.
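# Residual block in PyTorch (the layers inside F are illustrative)
import torch.nn as nn
class ResBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
    def forward(self, x):
        return x + self.layers(x)  # skip connection: output is F(x) + x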
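The initialization and clipping items from the list can be sketched in a single training step. This is one possible arrangement under stated assumptions, not a recommended recipe: the model shape, the random data, the learning rate, and `max_norm=1.0` are placeholders picked for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small ReLU network; the sizes are arbitrary.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))

# He (Kaiming) initialization for every Linear layer.
for module in model.modules():
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(8, 16), torch.randn(8, 1)   # random stand-in data

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Rescale gradients so their combined L2 norm is at most max_norm.
# The return value is the norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

print(f"gradient norm before clipping: {float(total_norm):.4f}")
```

Clipping goes between `loss.backward()` and `optimizer.step()`, so the optimizer only ever sees the rescaled gradients. Logging the returned pre-clip norm is a cheap way to notice gradients that are exploding or, if the value stays tiny, vanishing.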