datarekha
Patterns · June 2, 2026

Why we normalize: batch norm, intuitively

Internal covariate shift is not a subtle bug — it is the reason deep networks used to need weeks of careful babysitting, and batch norm is why that era quietly ended.

9 min read · by datarekha · deep-learning · batch-norm · layer-norm · training-dynamics · neural-networks

In 2015, training a deep image classifier from scratch took days on a good GPU cluster. It demanded careful weight initialization rituals, and it required a learning rate low enough that you could go on holiday between meaningful updates. That same year, Sergey Ioffe and Christian Szegedy at Google published a technique that let them train an Inception image classifier with learning rates 5 to 30 times higher than the baseline and reach the baseline's accuracy in 14 times fewer training steps. The technique was batch normalization. It did not change the architecture. It did not change the loss function. It changed where the activations lived.

That is the whole story, told in one paragraph. The rest is the intuition for why location matters that much.

The covariate shift problem, and why deep nets have it internally

In classical machine learning, covariate shift (the distribution of inputs changing between training and deployment) is treated as a data pipeline problem. Deep learning introduced a version of this problem that lives entirely inside the network, and it compounds with every layer.

Imagine a four-layer network in mid-training. Layer 1 is updating its weights. That update changes the distribution of layer 1’s output — the numbers fed into layer 2. Layer 2 was trained to expect a certain distribution of inputs, but now that distribution has shifted. So layer 2 updates its weights. That update shifts the distribution of layer 2’s outputs. Layer 3 sees a new distribution. And so on.

This is internal covariate shift: each layer’s input distribution changes throughout training because all the preceding layers are simultaneously updating. It is like trying to hit a moving target while the target is strapped to another moving target.

The consequence is not just slow training. It is pathological training. When activations in the early layers drift to large values, the sigmoid or tanh nonlinearities saturate — the gradient of sigmoid(x) near x = 50 is effectively zero. No useful gradient flows back. The vanishing gradient problem, so often blamed on architecture, is partly an activation-scale problem.
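To put a number on saturation, here is a small sketch that evaluates the derivative of the sigmoid, σ'(x) = σ(x)(1 − σ(x)), at a few pre-activation values. The helper names are illustrative only.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x: float) -> float:
    # sigmoid'(x) = s(x) * (1 - s(x)). Since 1 - s(x) == s(-x), we compute it that way.
    # In float64, 1.0 - s(50) rounds to exactly 0, which would hide the true tiny value.
    return sigmoid(x) * sigmoid(-x)

for x in (0.0, 2.0, 5.0, 50.0):
    print(f"x = {x:5.1f}   sigmoid'(x) = {sigmoid_grad(x):.3e}")
```

The gradient peaks at 0.25 when x = 0. It falls by roughly a factor of 40 at x = 5, and at x = 50 it is on the order of 10⁻²². Any gradient passing backward through that unit is multiplied by this factor.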

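Go the other way, with activations collapsing toward zero, and a saturating nonlinearity such as tanh operates in its nearly linear region, where tanh(x) ≈ x. Every layer starts to behave linearly, and a stack of linear layers composes into a single linear map. The expressive depth you paid for becomes decorative.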
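The near-linear regime is easy to check numerically. This sketch measures how far tanh(x) departs from the identity, relative to |x|, for small and moderate inputs. The function name is hypothetical.

```python
import math

def relative_nonlinearity(x: float) -> float:
    """Relative gap between tanh(x) and the identity map x."""
    return abs(math.tanh(x) - x) / abs(x)

for x in (0.01, 0.1, 1.0, 3.0):
    print(f"x = {x:5.2f}   tanh(x) = {math.tanh(x):.5f}   gap = {relative_nonlinearity(x):.2%}")
```

At x = 0.01 the gap is a few thousandths of a percent, so the unit is effectively linear. At x = 1 the gap is over 20%.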

What “normalize” actually means for a layer

Batch normalization (proposed by Ioffe and Szegedy in 2015) operates on each activation dimension independently, over a mini-batch of training examples. For a single neuron, with a batch of m examples, it:

  1. Computes the batch mean and batch variance for that neuron’s activations across the m examples.
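  2. Subtracts the mean and divides by the standard deviation, computed as the square root of the variance plus a small constant ε that prevents division by zero. This is the normalize step. The result has mean 0 and variance 1 (very slightly below 1, because of ε).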
  3. Scales and shifts by two learned parameters, gamma and beta. These let the network decide, through gradient descent, how much variance and offset is actually useful for this activation.
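Written out for one neuron with activations x_1, …, x_m over the mini-batch, the three steps are the standard batch norm equations. Here ε is a small constant, for example 10⁻⁵.

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i
&\qquad
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
&\qquad
y_i &= \gamma\,\hat{x}_i + \beta
\end{aligned}
```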

The learned gamma and beta are crucial. Pure normalization to zero mean and unit variance every layer would destroy representational power — the network could not learn that some activations should be large, or offset. The scale-and-shift step hands that control back to the optimizer, but in a stable, well-conditioned space.
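Here is a minimal NumPy sketch of the training-mode forward pass, assuming inputs shaped (batch, features). It also checks the claim about gamma and beta. If gamma equals the batch standard deviation and beta equals the batch mean, the normalization is undone, so nothing is lost representationally. The name `batch_norm_train` is illustrative, not a library API.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm for x of shape (batch, features).

    Statistics are computed per feature (column) across the batch (axis 0).
    """
    mean = x.mean(axis=0)                      # one mean per feature
    var = x.var(axis=0)                        # biased (1/m) variance per feature
    x_hat = (x - mean) / np.sqrt(var + eps)    # normalize step
    return gamma * x_hat + beta                # learned scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=[3.0, -10.0, 50.0], scale=[1.0, 5.0, 20.0], size=(64, 3))

# gamma = 1, beta = 0: every feature ends up with mean ~0 and variance ~1.
y = batch_norm_train(x, gamma=np.ones(3), beta=np.zeros(3))
print(y.mean(axis=0).round(6), y.var(axis=0).round(4))

# gamma = sqrt(var + eps), beta = mean: the layer reproduces its input.
gamma_id = np.sqrt(x.var(axis=0) + 1e-5)
beta_id = x.mean(axis=0)
print(np.allclose(batch_norm_train(x, gamma_id, beta_id), x))  # True
```

In a real network, gradient descent will rarely land exactly on this identity setting. The point is that the identity is reachable, so normalization does not narrow what the layer can represent.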

The result is that each layer, regardless of what the previous layers are doing, sees activations that occupy a predictable neighborhood. The target stopped moving.

[Figure: Activations across depth, without vs. with batch norm. Without normalization, layer means drift from ≈0 at layer 1 to ≈4, ≈18, and ≈64 at layer 4, where units saturate. With batch norm, every layer stays in a stable band at μ = 0, σ = 1.]

Without normalization, activations shift and widen with each layer until nonlinearities saturate. Batch norm re-centers each layer’s output to a consistent distribution, keeping every layer in a trainable regime.
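The figure's pattern can be reproduced in a toy simulation. This sketch makes several assumptions: four tanh layers, unscaled N(0, 1) weights (deliberately poor initialization), and a batch of 512 random inputs. At each layer it measures the fraction of saturated units (|tanh| > 0.99), with and without a batch-norm step (gamma = 1, beta = 0) before each nonlinearity. The helper names are hypothetical, and the numbers it prints are illustrative rather than the figure's.

```python
import numpy as np

def normalize(h, eps=1e-5):
    # Batch norm with gamma = 1, beta = 0: per-feature statistics over the batch axis.
    return (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + eps)

def saturation_by_layer(use_bn, depth=4, width=64, batch=512, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, width))
    fractions = []
    for _ in range(depth):
        # Unscaled N(0, 1) weights: pre-activation std is roughly sqrt(width) = 8.
        w = rng.standard_normal((width, width))
        h = x @ w
        if use_bn:
            h = normalize(h)
        x = np.tanh(h)
        fractions.append(float(np.mean(np.abs(x) > 0.99)))
    return fractions

# Same seed for both runs, so both see identical inputs and weights.
print("without BN:", np.round(saturation_by_layer(use_bn=False), 3))
print("with BN:   ", np.round(saturation_by_layer(use_bn=True), 3))
```

Without normalization, most units sit in the flat tails of tanh at every layer. With the normalization step, only around 1% do.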

Why this unlocks higher learning rates

The learning rate is a multiplier on the gradient before subtracting it from the weights. A high learning rate means large steps. Large steps in a poorly conditioned loss landscape — where gradients point wildly different directions depending on tiny changes in parameters — cause overshooting, oscillation, and divergence.

Batch norm smooths the loss landscape. With activations in a predictable range, the gradient through any one layer is not amplified or suppressed by the accident of where activations happen to live at this moment in training. The effective curvature of the loss surface becomes more isotropic (similar in all directions), and you can step confidently in the gradient direction without worrying that a slightly larger step tips you into a pathological region.
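A toy calculation makes the step-size argument concrete. The sketch below is plain gradient descent on a quadratic loss f(w) = ½ Σ λᵢ wᵢ², not batch norm itself. Each step multiplies wᵢ by (1 − lr·λᵢ), so the iteration is stable only if lr < 2 / max λᵢ. When curvatures differ wildly, the steepest direction caps the learning rate for every direction. When they are similar, a much larger step is safe.

```python
import numpy as np

def gradient_descent(curvatures, lr, steps=100):
    """Minimize f(w) = 0.5 * sum(curvatures * w**2), starting from w = 1."""
    lam = np.asarray(curvatures, dtype=float)
    w = np.ones_like(lam)
    for _ in range(steps):
        grad = lam * w          # gradient of the quadratic
        w = w - lr * grad       # each coordinate is scaled by (1 - lr * lam)
    return w

ill = [1.0, 100.0]   # curvatures differ 100x: stable only for lr < 0.02
well = [1.0, 1.0]    # isotropic: stable for lr < 2

print(gradient_descent(ill, lr=0.05))    # steep direction oscillates and blows up
print(gradient_descent(ill, lr=0.015))   # stable, but the shallow direction barely moves
print(gradient_descent(well, lr=0.5))    # large steps, fast convergence
```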

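This is not metaphor. Santurkar et al. (2018) showed, empirically and with supporting theory, that batch norm makes the loss landscape significantly smoother. Two quantities become smaller: the Lipschitz constant of the loss (a bound on how fast the loss can change) and the Lipschitz constant of its gradient (a bound on how fast the gradient can change). You can take bigger steps safely because the terrain is less treacherous.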

Inference: a quiet inconvenience

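Batch norm has a practical wrinkle at inference time. During training, mean and variance are computed from the current mini-batch. But at inference, you might feed a single example, and the batch statistics of one example are degenerate. The mean is the example itself and the variance is zero, so every normalized activation becomes 0 and the layer outputs β regardless of the input. Even with larger batches, an example's output would depend on which other examples happened to share its batch.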

The fix is a running average. During training, batch norm tracks an exponential moving average of the batch means and variances. At inference time, those frozen running statistics replace the per-batch computation. The layer then becomes a fixed affine transform of each input. Its output for an example no longer depends on the rest of the batch, and it uses statistics accumulated over the training distribution rather than recomputed on the fly.
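A minimal sketch of the two modes, assuming inputs shaped (batch, features). The update rule running = (1 − momentum)·running + momentum·batch, with momentum 0.1, follows the convention PyTorch documents for `nn.BatchNorm1d`. This class is a simplified stand-in, not that implementation.

```python
import numpy as np

class BatchNorm1d:
    """Minimal batch norm for inputs of shape (batch, features), with train/eval modes."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def __call__(self, x):
        if self.training:
            mean, var = x.mean(axis=0), x.var(axis=0)
            # Exponential moving average of the batch statistics, kept for inference.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Inference: frozen running statistics, so one example's output
            # does not depend on any other example.
            mean, var = self.running_mean, self.running_var
        return self.gamma * (x - mean) / np.sqrt(var + self.eps) + self.beta


rng = np.random.default_rng(0)
bn = BatchNorm1d(3)
for _ in range(200):                         # "training" on data drawn from N(5, 2^2)
    bn(rng.normal(5.0, 2.0, size=(32, 3)))

single = np.array([[7.0, 7.0, 7.0]])
bn.training = False
print(bn.running_mean.round(2), bn.running_var.round(2))
print(bn(single))          # close to (7 - 5) / 2 = 1 for each feature

naive = BatchNorm1d(3)     # a fresh layer left in training mode
print(naive(single))       # batch of one: every output collapses to beta = 0
```

Forgetting to switch a model to inference mode produces exactly the `naive` behavior. It is a common source of "the model works in training but outputs garbage in production" bugs.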

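This works well when inference inputs resemble the training data, as they usually do for images, text classification, or regression. It becomes fragile when the test distribution deviates substantially from training, because the frozen statistics no longer describe the activations the layer actually receives. It is a reminder that normalization is not magic, just sensible bookkeeping.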

Why transformers use layer norm instead

Layer normalization (layer norm) is batch norm’s close cousin with one key structural difference: it normalizes over the feature dimension of a single example, rather than over the batch dimension across examples.

For a single training example, layer norm computes the mean and variance across all the neurons in that layer’s activation vector, then normalizes. No batch statistics involved.
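Written out for one example's activation vector a = (a_1, …, a_H) over the H neurons of a layer, layer norm computes the following. Compared with batch norm, the sums run over the features j of a single example instead of over the examples in a batch. There is still a learned γⱼ and βⱼ per feature.

```latex
\begin{aligned}
\mu &= \frac{1}{H}\sum_{j=1}^{H} a_j
&\qquad
\sigma^2 &= \frac{1}{H}\sum_{j=1}^{H} \left(a_j - \mu\right)^2 \\
\hat{a}_j &= \frac{a_j - \mu}{\sqrt{\sigma^2 + \epsilon}}
&\qquad
y_j &= \gamma_j\,\hat{a}_j + \beta_j
\end{aligned}
```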

Transformers adopted layer norm for a practical reason that has nothing to do with expressiveness: sequence data does not batch as cleanly as fixed-size images. A transformer processes sequences of varying lengths, typically padded to a common length, and the attention mechanism mixes information across positions within a single sequence. Batch statistics would have to be pooled over every position of every sequence in the mini-batch, padding included. They would then depend on sequence lengths and on which sequences happened to be batched together.

Layer norm sidesteps this entirely. Each token’s activation vector is independently normalized over its own features. You can run a single token or a thousand tokens through the same normalization code, and the statistics are always meaningful. This also makes layer norm substantially easier to implement in autoregressive generation, where you produce one token at a time.
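The practical difference shows up directly in code. This NumPy sketch assumes activations shaped (tokens, features), and its helper names are illustrative. It normalizes the same token alone and inside a batch of other tokens. Layer norm gives identical results both ways; normalization with batch statistics does not.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics over the feature axis (last axis): one mean/variance per token.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def batch_stats_norm(x, eps=1e-5):
    # For contrast: statistics over the batch axis, as batch norm uses in training.
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

rng = np.random.default_rng(0)
d = 8
gamma, beta = np.ones(d), np.zeros(d)
tokens = rng.standard_normal((1000, d)) * 3.0 + 1.0
one_token = tokens[:1]                        # shape (1, d)

ln_alone = layer_norm(one_token, gamma, beta)
ln_in_batch = layer_norm(tokens, gamma, beta)[:1]
print(np.allclose(ln_alone, ln_in_batch))     # True: same answer at any batch size

bn_alone = batch_stats_norm(one_token)
bn_in_batch = batch_stats_norm(tokens)[:1]
print(np.allclose(bn_alone, bn_in_batch))     # False: depends on the other tokens
```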

There is a subtler reason too. Batch norm couples examples: each example's output depends on the others in its mini-batch. A deep transformer applies normalization at every sublayer along its residual stream (the running sum of sublayer outputs that flows through the network), so batch norm would inject that cross-example coupling dozens of times per forward pass. Layer norm composes more cleanly with residual connections because it never mixes information between examples in a batch.

The real lesson: signal conditioning is infrastructure

There is a tendency to treat normalization as a hyperparameter choice — batch norm or layer norm, before the activation or after, with or without the affine parameters. That framing misses the point.

Normalization is infrastructure for signal quality. The network learns from gradients. Gradients are derivatives of a loss with respect to weights. Those derivatives depend on activation values. If activation values are in a regime where derivatives are informative — not saturated, not collapsed — the network learns. If they are not, the network stalls regardless of how clever your architecture is.

This is why batch norm was such a qualitative change in practice. It was not a new kind of layer. It was a guarantee about the signal quality a layer receives. With that guarantee in place, you could train deeper networks, use higher learning rates, be less paranoid about initialization, and get useful models out of a training run that would previously have diverged.

The analogy I keep returning to: a speaker system can be technically excellent, but if the input signal is clipping, the speakers reproduce noise. You fix the signal before the speaker, not the speaker itself. Batch norm fixes the signal at every stage of the network.

Practical intuitions to keep

If you are training a convolutional network on images, batch norm is the default. Use a batch size large enough that the batch statistics are stable — below 16 examples per GPU, batch norm starts to become unreliable because the per-batch mean and variance are noisy estimates. At very small batches, group norm (normalizes over subgroups of channels within a single example) is the standard alternative.
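Group norm is a small change to the same recipe. This NumPy sketch assumes image activations in (N, C, H, W) layout and a channel count divisible by the number of groups. Statistics are computed per example and per group of channels, so the result does not depend on the batch size at all. The group count of 4 is chosen only for illustration.

```python
import numpy as np

def group_norm(x, num_groups, gamma, beta, eps=1e-5):
    """Group norm for x of shape (N, C, H, W); gamma and beta have shape (C,)."""
    n, c, h, w = x.shape
    assert c % num_groups == 0, "channels must divide evenly into groups"
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    # One mean/variance per (example, group): reduce over channels-in-group and space.
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    x_hat = ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    # Per-channel affine parameters, broadcast over N, H, W.
    return gamma.reshape(1, c, 1, 1) * x_hat + beta.reshape(1, c, 1, 1)

rng = np.random.default_rng(0)
x = rng.normal(2.0, 3.0, size=(2, 8, 4, 4))        # a batch of just two examples
y = group_norm(x, num_groups=4, gamma=np.ones(8), beta=np.zeros(8))

groups = y.reshape(2, 4, -1)                         # (example, group, values)
print(groups.mean(axis=-1).round(6))                 # ~0 for every (example, group)
print(groups.var(axis=-1).round(4))                  # ~1 for every (example, group)

# Same output for the first example whether or not the second is in the batch.
alone = group_norm(x[:1], num_groups=4, gamma=np.ones(8), beta=np.zeros(8))
print(np.allclose(alone, y[:1]))                     # True
```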

If you are training or fine-tuning a transformer, layer norm is already in the architecture. The original transformer paper used it, and its successors have kept it or a close variant such as RMSNorm, which rescales by the root mean square without subtracting the mean. Pre-norm (normalizing before the attention and feed-forward sublayers rather than after) has become the default in modern models because it improves training stability at very large depth.
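The pre-norm versus post-norm distinction is only about where the normalization sits relative to the residual addition. Below is a minimal NumPy sketch. It assumes hypothetical stand-in functions in place of attention or the feed-forward block, and it uses layer norm without learned parameters for brevity. One common explanation for pre-norm's stability is visible here: its residual path carries the input through untouched.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Per-token normalization over the feature axis (gamma = 1, beta = 0 for brevity).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def post_norm_block(x, sublayer):
    # Original transformer ordering: add the residual, then normalize.
    return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer):
    # Pre-norm ordering: normalize the sublayer's input; the residual path is untouched.
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16)) * 5.0 + 2.0         # 4 tokens, 16 features
w = rng.standard_normal((16, 16)) / 4.0

def toy_sublayer(h):
    # Hypothetical stand-in for attention or a feed-forward network.
    return np.tanh(h @ w)

def zero_sublayer(h):
    # A sublayer that contributes nothing (e.g. early in training).
    return np.zeros_like(h)

# With a do-nothing sublayer, pre-norm is exactly the identity,
# while post-norm still rescales every token.
print(np.allclose(pre_norm_block(x, zero_sublayer), x))    # True
print(np.allclose(post_norm_block(x, zero_sublayer), x))   # False
print(pre_norm_block(x, toy_sublayer).shape, post_norm_block(x, toy_sublayer).shape)
```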

If you ever see a network failing to train — loss not moving, gradients vanishing, activations collapsing — check the activation statistics before you blame the architecture. The diagnosis usually lives in the distribution of what each layer is receiving, not in the choice of attention head count or hidden dimension.
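A diagnostic along those lines can be as simple as summarizing each layer's activations. The sketch below assumes you have already captured per-layer activations as NumPy arrays, for example with your framework's forward hooks. The helper name `activation_report` and its thresholds are illustrative choices, not a standard.

```python
import numpy as np

def activation_report(activations, saturation=0.99, dead=1e-6):
    """Summarize captured activations: {layer_name: array} -> {layer_name: stats}."""
    report = {}
    for name, a in activations.items():
        a = np.asarray(a, dtype=float)
        report[name] = {
            "mean": float(a.mean()),
            "std": float(a.std()),
            # Fraction of values pinned near +/-1 (saturated tanh-like units).
            "saturated": float(np.mean(np.abs(a) > saturation)),
            # Fraction of values that are effectively zero (collapsed or dead units).
            "near_zero": float(np.mean(np.abs(a) < dead)),
        }
    return report

rng = np.random.default_rng(0)
captured = {
    "healthy": np.tanh(rng.standard_normal((256, 64))),
    "saturated": np.tanh(20.0 * rng.standard_normal((256, 64))),
    "collapsed": np.maximum(0.0, rng.standard_normal((256, 64)) - 10.0),
}
for name, stats in activation_report(captured).items():
    print(f"{name:10s} " + "  ".join(f"{k}={v:.3f}" for k, v in stats.items()))
```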

The math behind normalization is not difficult. The insight is recognizing that deep networks are not just function approximators — they are signal processing chains, and signal conditioning is not optional.
