What is batch normalisation, and why does it help training?
Batch normalisation normalises each feature across the mini-batch to zero mean and unit variance, then applies learnable scale and shift parameters. It stabilises internal activation distributions (the original paper framed this as reducing internal covariate shift), which allows higher learning rates, reduces dependence on careful weight initialisation, and provides mild regularisation through the noise in batch statistics.
How to think about it
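Batch normalisation (BN) was introduced by Ioffe and Szegedy in 2015 and became a staple of deep CNN training. In the original paper, a batch-normalised image classifier matched the baseline's accuracy in 14 times fewer training steps, largely because BN tolerates much higher learning rates.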
What it computes
For a mini-batch of activations {x₁, …, x_m} at a given layer:
μ_B = (1/m) Σ x_i # batch mean
σ²_B = (1/m) Σ (x_i − μ_B)² # batch variance
x̂_i = (x_i − μ_B) / √(σ²_B + ε) # normalise
y_i = γ·x̂_i + β # scale and shift (learnable)
The learnable parameters γ and β restore the representational power that normalisation removes — the network can learn to undo BN if that is optimal.
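The four formulas above can be traced by hand for a single feature. The following is a minimal pure-Python sketch, not a framework implementation; the function name `batch_norm` and the example batch are illustrative assumptions. It also shows the sense in which γ and β can undo the normalisation: setting γ = √(σ²_B + ε) and β = μ_B recovers the input.

```python
import math

def batch_norm(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalise one feature across a mini-batch, then scale and shift.

    Hypothetical helper for illustration; xs is a list of floats.
    """
    m = len(xs)
    mean = sum(xs) / m                            # batch mean (mu_B)
    var = sum((x - mean) ** 2 for x in xs) / m    # biased batch variance (sigma^2_B)
    x_hat = [(x - mean) / math.sqrt(var + eps) for x in xs]  # normalise
    ys = [gamma * xh + beta for xh in x_hat]      # scale and shift
    return ys, mean, var

batch = [2.0, 4.0, 6.0, 8.0]

# With gamma=1 and beta=0 the output has zero mean and (almost) unit variance.
ys, mean, var = batch_norm(batch)

# gamma = sqrt(var + eps) and beta = mean reverse the normalisation exactly.
restored, _, _ = batch_norm(batch, gamma=math.sqrt(var + 1e-5), beta=mean)
```

The output variance is slightly below 1 because ε is added to the variance inside the square root.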
Training vs inference
During training, BN uses the mean and variance of the current mini-batch. It simultaneously updates running estimates of mean and variance via exponential moving average.
During inference, those running estimates replace the batch statistics so results are deterministic and batch-independent.
import torch.nn as nn
# BatchNorm2d for convolutional feature maps (normalises over N, H, W per channel)
bn = nn.BatchNorm2d(num_features=64, momentum=0.1, eps=1e-5)
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    bn,
    nn.ReLU(),
)
model.train() # uses batch stats, updates running estimates
model.eval() # uses running estimates — REQUIRED before inference
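The train/eval switch can also be sketched without a framework. The toy class below is an assumption-laden illustration: the name `RunningBatchNorm` is hypothetical, it handles one feature only, it omits γ and β, and it averages the biased batch variance for simplicity. The update rule follows the convention in which `momentum` weights the newest batch statistic.

```python
import math

class RunningBatchNorm:
    """Toy single-feature batch norm that tracks running statistics."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0
        self.training = True

    def __call__(self, xs):
        if self.training:
            # Training: normalise with the statistics of this mini-batch.
            m = len(xs)
            mean = sum(xs) / m
            var = sum((x - mean) ** 2 for x in xs) / m
            # Exponential moving average of the batch statistics.
            self.running_mean += self.momentum * (mean - self.running_mean)
            self.running_var += self.momentum * (var - self.running_var)
        else:
            # Inference: use the stored estimates, ignoring the batch contents.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in xs]

bn_toy = RunningBatchNorm()
for _ in range(200):                 # repeated training steps on the same batch
    bn_toy([2.0, 4.0, 6.0, 8.0])

bn_toy.training = False              # analogous to model.eval()
single = bn_toy([5.0])               # a batch of one works at inference
pair = bn_toy([5.0, 100.0])          # the other sample does not change the result
```

In inference mode the output for 5.0 is identical in both calls, which is what makes predictions independent of whatever else is in the batch.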
Why it helps
- Higher LR tolerance: normalised inputs prevent activations from exploding or collapsing, so larger update steps are safe.
- Reduced sensitivity to weight init: even poorly initialised layers are quickly brought to a stable range.
- Mild regularisation: per-batch noise in μ_B and σ²_B acts like data augmentation, often letting you reduce dropout.
- Faster convergence: fewer epochs needed to reach the same validation accuracy.
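The reduced sensitivity to initialisation can be seen in a small numerical sketch. It covers only the forward pass: rescaling and offsetting a layer's pre-activations, as a poorly chosen initial weight scale and bias might, leaves the normalised values almost unchanged. The helper `normalise` and the numbers are illustrative assumptions, and the example says nothing about training dynamics.

```python
import math

def normalise(xs, eps=1e-5):
    """Zero-mean, unit-variance normalisation of one feature across a batch."""
    m = len(xs)
    mean = sum(xs) / m
    var = sum((x - mean) ** 2 for x in xs) / m
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

pre_activations = [0.3, -1.2, 0.8, 2.1]

# Pretend a badly initialised layer yields the same values, 100x larger
# and offset by 50.
badly_scaled = [100.0 * x + 50.0 for x in pre_activations]

a = normalise(pre_activations)
b = normalise(badly_scaled)

# Largest element-wise difference between the two normalised batches.
max_gap = max(abs(u - v) for u, v in zip(a, b))
```

The small remaining gap comes from ε, which matters slightly more when the raw variance is small.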
Alternatives
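BatchNorm works poorly with small batches, where the batch statistics become noisy (a training batch of one has no usable variance), and it fits awkwardly in sequence models with variable-length inputs. In those cases use LayerNorm (normalises across the features of each sample, standard in transformers), GroupNorm (splits channels into groups and normalises within each group per sample, popular in detection where batches are small), or InstanceNorm (normalises each channel of each sample separately, used in style transfer).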
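The main difference between these methods is the axis over which statistics are computed. A minimal sketch of the batch-versus-layer distinction follows, using a tiny matrix of three samples and two features; the helper `normalise` is hypothetical and the learnable scale and shift are left out.

```python
import math

def normalise(xs, eps=1e-5):
    """Zero-mean, unit-variance normalisation of a list of floats."""
    m = len(xs)
    mean = sum(xs) / m
    var = sum((x - mean) ** 2 for x in xs) / m
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

# Rows are samples, columns are features.
acts = [
    [1.0, 10.0],
    [2.0, 20.0],
    [3.0, 30.0],
]

# Batch-norm style: normalise each feature (column) across the batch.
columns = [normalise(list(col)) for col in zip(*acts)]
batch_normed = [list(row) for row in zip(*columns)]

# Layer-norm style: normalise each sample (row) across its own features.
layer_normed = [normalise(row) for row in acts]

# A sample normalised on its own gives the same layer-norm result,
# so layer norm does not depend on batch size.
alone = normalise(acts[0])
```

GroupNorm and InstanceNorm follow the same per-sample pattern as the layer-norm branch, but compute statistics over channel groups or single channels rather than over all features.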