Weight initialization
Set the starting weights too big and the signal explodes; too small and it vanishes. Why Xavier and He initialization are the scales that let deep nets train at all.
What you'll learn
- Why activation variance compounds with depth and breaks naive initialization
- Why initializing all weights to zero never learns (the symmetry problem)
- When to use Xavier/Glorot vs Kaiming/He, and the residual-branch trick in transformers
Before you start
In the training-loop demo we started one neuron
at w = 0, b = 0 and it learned fine. Try that in a fifty-layer network and it
will never train — not slowly, never. Initialization is the least glamorous
choice in deep learning and one of the most decisive: get the scale wrong and
your gradients are dead before the first step.
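The "never" can be checked on a toy case. What follows is a minimal NumPy sketch, not the code from the training-loop demo: the two-layer tanh network, its sizes, the learning rate, and the random data are all illustrative assumptions. Every parameter starts at zero and the gradients are computed by hand.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4))          # toy inputs (assumed shape)
y = rng.normal(size=(32, 1))          # toy regression targets

# All-zero initialization for a 4 -> 8 -> 1 tanh network
W1, b1 = np.zeros((4, 8)), np.zeros(8)
W2, b2 = np.zeros((8, 1)), np.zeros(1)

lr = 0.1
for step in range(100):
    # Forward pass
    h = np.tanh(X @ W1 + b1)          # hidden activations: tanh(0) = 0
    y_hat = h @ W2 + b2

    # Backward pass for mean squared error
    d_out = 2.0 * (y_hat - y) / len(X)
    dW2 = h.T @ d_out                 # zero, because h is zero
    db2 = d_out.sum(axis=0)
    d_h = d_out @ W2.T                # zero, because W2 is zero
    d_pre = d_h * (1.0 - h ** 2)
    dW1 = X.T @ d_pre
    db1 = d_pre.sum(axis=0)

    # Plain gradient descent update
    W1 -= lr * dW1
    b1 -= lr * db1
    W2 -= lr * dW2
    b2 -= lr * db2

print("max |W1|:", np.abs(W1).max())
print("max |W2|:", np.abs(W2).max())
print("b2:", b2[0], " mean(y):", y.mean())
```

Only the output bias receives a gradient, so the network can fit the mean of the targets and nothing else. The hidden activations stay at zero, which keeps the gradient for W2 at zero, and W2 staying at zero blocks any gradient from reaching the first layer.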
The problem: variance compounds with depth
Picture a signal flowing forward through layers. Each layer multiplies it by a weight matrix. A rough rule for the variance of one layer's output, assuming zero-mean, independent weights and inputs:
var(output) ≈ fan_in × var(weights) × var(input)
where fan_in is the number of inputs to the layer (often hundreds or
thousands). Now stack many layers. The multiplier fan_in × var(weights) applies at every layer, so the
variance grows or shrinks exponentially with depth:
- Weights too big → fan_in × var(weights) > 1 → variance explodes → activations saturate, gradients blow up to NaN.
- Weights too small → multiplier < 1 → variance decays toward zero → activations vanish, gradients vanish, nothing updates.
The signal has to thread a needle: each layer must roughly preserve the variance. Watch it happen — push a unit-variance signal through 12 layers and see which initialization keeps it alive:
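As a back-of-the-envelope version of that experiment, the sketch below applies the rough variance rule 12 times in a row. It is pure arithmetic on the rule, with no nonlinearity, and the layer width of 512 and the three candidate standard deviations are assumptions chosen for illustration.

```python
fan_in = 512      # assumed layer width
depth = 12        # number of stacked layers

# Illustrative weight standard deviations to compare
inits = {
    "too small (std = 0.01)": 0.01,
    "too big (std = 0.1)": 0.1,
    "scaled (std = 1/sqrt(fan_in))": fan_in ** -0.5,
}

def variance_after(std, fan_in=fan_in, depth=depth, var_in=1.0):
    """Apply var(out) = fan_in * var(w) * var(in) once per layer."""
    multiplier = fan_in * std ** 2
    var = var_in
    for _ in range(depth):
        var *= multiplier
    return var

for name, std in inits.items():
    multiplier = fan_in * std ** 2
    print(f"{name:32s} multiplier = {multiplier:7.4f}   "
          f"variance after {depth} layers = {variance_after(std):.3e}")
```

A per-layer multiplier of about 5 or about 0.05 looks harmless for one layer, but twelve layers turn it into a change of many orders of magnitude in either direction.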
The fix: scale to fan_in
The whole game is choosing var(weights) so the per-layer multiplier is ~1. Two
schemes dominate, and which one is correct depends on the activation:
- Xavier / Glorot: var(weights) = 1 / fan_in (often 2 / (fan_in + fan_out)). Designed for tanh / sigmoid, whose linear regime preserves variance.
- Kaiming / He: var(weights) = 2 / fan_in. Designed for ReLU. The factor of 2 compensates for ReLU zeroing out half its inputs, which otherwise halves the variance every layer.
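To make the two formulas concrete, here is a small sketch that turns them into weight standard deviations. The helper names xavier_std and he_std are hypothetical, written for this example rather than taken from any library, and the 512-wide layer is an assumed size.

```python
import math

def xavier_std(fan_in, fan_out):
    """Glorot scale: var(weights) = 2 / (fan_in + fan_out)."""
    return math.sqrt(2.0 / (fan_in + fan_out))

def he_std(fan_in):
    """He scale: var(weights) = 2 / fan_in."""
    return math.sqrt(2.0 / fan_in)

fan_in, fan_out = 512, 512   # assumed layer shape
print("Xavier std:", xavier_std(fan_in, fan_out))
print("He std:    ", he_std(fan_in))
print("ratio He / Xavier:", he_std(fan_in) / xavier_std(fan_in, fan_out))
```

For a square layer the two scales differ only by a factor of sqrt(2) in standard deviation, which is exactly the factor of 2 in variance that offsets ReLU.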
The rule of thumb: ReLU → He, tanh/sigmoid → Xavier. PyTorch’s defaults are reasonable, but for ReLU nets you often set it explicitly:
import torch.nn as nn
layer = nn.Linear(512, 512)
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu") # He init
nn.init.zeros_(layer.bias) # bias starts at 0
See the variance survive (or not)
Mirror the visualizer with real numbers: propagate a signal through 15 ReLU layers and print the activation std at each. Naive init explodes; He init holds steady near 1.
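One way to run that experiment is sketched below in NumPy. Several details are assumptions rather than the page's original code: "naive" is taken to mean weights drawn with standard deviation 1, the layers are 256 wide with no bias, and the input is a batch of 512 standard-normal vectors.

```python
import numpy as np

rng = np.random.default_rng(0)
width, depth, batch = 256, 15, 512    # assumed sizes

x0 = rng.normal(size=(batch, width))  # unit-variance input signal

def run(weight_std):
    """Push x0 through `depth` ReLU layers; return the activation std per layer."""
    x, stds = x0, []
    for _ in range(depth):
        W = rng.normal(scale=weight_std, size=(width, width))
        x = np.maximum(x @ W, 0.0)    # linear layer (no bias) followed by ReLU
        stds.append(float(x.std()))
    return stds

naive = run(1.0)                      # assumed "naive" init: std = 1
he = run(np.sqrt(2.0 / width))        # He init: var(weights) = 2 / fan_in

print("layer        naive         He")
for i, (n, h) in enumerate(zip(naive, he), start=1):
    print(f"{i:5d} {n:12.3e} {h:10.3f}")
```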
The naive column races off to enormous numbers; the He column stays near 1 all the way down. That difference is the difference between a network that trains and one that returns NaN on step one.
Quick check
Next
Good init gives the gradients a fighting start. But depth can still strangle them mid-training — vanishing & exploding gradients covers how to diagnose that with gradient norms and fix it with clipping and normalization.
Practice this in an interview
Poor initialization causes the variance of activations to either explode or collapse across layers, triggering vanishing or exploding gradients before training even begins. Xavier initialization targets variance preservation for saturating activations; He initialization corrects for the halved variance caused by ReLU zeroing negative inputs.
Both scale initial weights based on layer fan-in and fan-out to keep activation and gradient variance stable across layers. Xavier (Glorot) assumes a symmetric activation like tanh or sigmoid, while He initialization uses a larger variance tuned for ReLU-family activations, which zero out half their inputs. Use Xavier with tanh or sigmoid and He with ReLU or LeakyReLU.
Vanishing gradients occur when repeated multiplication of small derivatives during backpropagation drives gradients toward zero, starving early layers of learning signal. The main fixes are better activations (ReLU/GELU), residual connections, batch normalization, and careful weight initialization.
Larger batches give more accurate gradient estimates and enable higher GPU utilisation, but they tend to converge to sharper minima that generalise worse. Smaller batches introduce gradient noise that acts as implicit regularisation, helping the optimiser escape sharp minima and often finding flatter, better-generalising solutions — at the cost of slower wall-clock training per epoch.