datarekha
Deep Learning · Medium · Asked at Google · Asked at OpenAI · Asked at Meta · Asked at Anthropic

What roles do residual connections and layer normalisation play in transformer training?

The short answer

Residual connections give gradients a direct path from the loss to every layer. This counters the vanishing-gradient and degradation problems that make very deep plain networks hard to train. Layer normalisation stabilises activations within each token's representation, independently of batch size and sequence length. That allows stable training at large depth and with the variable-length sequences typical in NLP.

How to think about it

Residual connections (skip connections)

In the original transformer, each sub-layer (self-attention or feed-forward) is wrapped as:

output = LayerNorm(x + SubLayer(x))

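The + x term makes the Jacobian of the residual sum y = x + SubLayer(x) with respect to x equal to I + dSubLayer(x)/dx. The identity term passes the gradient straight through, whatever SubLayer does. A 24-block transformer has 48 sub-layers, one attention and one feed-forward sub-layer per block. Expanding the chain rule through 48 such sums yields a term that reaches layer 1 via the shortcuts alone, so the gradient does not have to traverse a series product of 48 non-linear Jacobians. Strictly, this clean identity path exists only when nothing sits on the shortcut. In the form above, where LayerNorm is applied after the add (known as Post-LN), each LayerNorm's Jacobian also multiplies the shortcut. This is the same mechanism that made ResNets trainable at 100+ layers in vision.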
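To see the effect numerically, here is a minimal pure-Python sketch. It assumes a toy scalar "sub-layer" f(x) = w * tanh(x) with a hypothetical weight w = 0.5, standing in for attention or a feed-forward network. The function `stack_gradient` is a hypothetical helper, not a library API. It propagates the derivative forward through 48 stacked layers, matching the 48 sub-layers of a 24-block model. It does this twice: once as a plain stack x -> f(x), and once as a residual stack x -> x + f(x).

```python
import math

def stack_gradient(x0: float, depth: int, w: float, residual: bool) -> float:
    """Return d(output)/d(x0) for a stack of toy scalar layers.

    Each layer is f(x) = w * tanh(x), a hypothetical stand-in for a sub-layer.
    With residual=True each layer computes x + f(x) instead of f(x).
    """
    x, grad = x0, 1.0
    for _ in range(depth):
        local = w * (1.0 - math.tanh(x) ** 2)  # f'(x) at this layer's input
        if residual:
            local += 1.0  # identity path adds 1 to the local derivative
            x = x + w * math.tanh(x)
        else:
            x = w * math.tanh(x)
        grad *= local  # chain rule: local derivatives multiply in series
    return grad

plain = stack_gradient(1.0, depth=48, w=0.5, residual=False)
resid = stack_gradient(1.0, depth=48, w=0.5, residual=True)
print(f"plain stack:    {plain:.3e}")
print(f"residual stack: {resid:.3e}")
```

The plain stack's gradient is a product of 48 factors, each at most 0.5, so it falls below 0.5**48 (about 3.6e-15). Every residual factor is at least 1, so the residual gradient cannot vanish. In a real network the sub-layer Jacobians are matrices rather than scalars, so this only illustrates the mechanism.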

Layer normalisation

LayerNorm(x) = gamma * (x - mu) / sqrt(sigma^2 + eps) + beta

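Here mu and sigma^2 are the mean and variance computed across the d_model features of a single token. gamma and beta are learned per-feature scale and shift vectors, and eps is a small constant for numerical stability. Compare this with batch normalisation, which normalises each feature across the batch (and, for sequences, across positions):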

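Aspect | BatchNorm | LayerNorm
Normalises over | Each feature, across the batch | Feature dimension, per token
Batch-size dependence | Yes (statistics degenerate at batch size 1) | None
Sequence-length dependence | Yes | None
Inference with a batch of one | Must use running statistics from training | Same per-token computation as training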

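Transformers for text routinely decode one token at a time, often with a batch of one, and train on padded batches of variable-length sequences. Training-mode BatchNorm statistics over such batches are noisy and include padding unless it is masked out. A batch of one yields no usable batch statistics, so inference must fall back on running averages that may not match training behaviour. LayerNorm performs the same per-token computation in training and inference, so it has no such constraint.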
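The contrast can be sketched in a few lines of pure Python. This sketch makes several assumptions: d_model = 4, unit gamma and zero beta, and a hand-written training-mode BatchNorm that uses per-feature statistics over the batch with no running averages. The helper names `layer_norm` and `batch_norm_train` are hypothetical and are not a library API. The sketch shows that a token's LayerNorm output does not depend on the other tokens in the batch, while training-mode BatchNorm's output does, and collapses at a batch of one.

```python
import math

def layer_norm(token, gamma, beta, eps=1e-5):
    """Normalise one token's feature vector over its own d_model features."""
    d = len(token)
    mu = sum(token) / d
    var = sum((v - mu) ** 2 for v in token) / d
    return [g * (v - mu) / math.sqrt(var + eps) + b
            for v, g, b in zip(token, gamma, beta)]

def batch_norm_train(batch, gamma, beta, eps=1e-5):
    """Simplified training-mode BatchNorm: normalise each feature across the batch."""
    n, d = len(batch), len(batch[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [row[j] for row in batch]
        mu = sum(col) / n
        var = sum((v - mu) ** 2 for v in col) / n
        for i in range(n):
            out[i][j] = gamma[j] * (batch[i][j] - mu) / math.sqrt(var + eps) + beta[j]
    return out

d_model = 4
gamma, beta = [1.0] * d_model, [0.0] * d_model
token = [1.0, 2.0, 3.0, 4.0]
other = [10.0, -3.0, 0.5, 7.0]

# LayerNorm: the token's output is the same whether it is alone or batched.
ln_alone = layer_norm(token, gamma, beta)
ln_batched = [layer_norm(t, gamma, beta) for t in [token, other]][0]

# BatchNorm (training mode): the output depends on the other batch members,
# and with a batch of one every feature collapses to beta.
bn_alone = batch_norm_train([token], gamma, beta)[0]
bn_batched = batch_norm_train([token, other], gamma, beta)[0]

print(ln_alone == ln_batched)  # True
print(bn_alone)                # [0.0, 0.0, 0.0, 0.0]
print(bn_batched != bn_alone)  # True
```

Real BatchNorm layers avoid the batch-of-one collapse at inference by switching to running statistics, which is the train/inference mismatch described above.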

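Pre-LN vs Post-LN: the original paper used Post-LN, normalising after the residual add as in the wrapper formula above. Modern large models overwhelmingly use Pre-LN, output = x + SubLayer(LayerNorm(x)). This normalises only the sub-layer's input and leaves the residual path as a pure identity. It improves gradient flow early in training and reduces the need for careful learning-rate warm-up. Pre-LN stacks typically apply one final LayerNorm after the last block.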
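The two arrangements differ only in where LayerNorm sits relative to the residual add. The following is a minimal pure-Python sketch. It assumes a hypothetical element-wise `sublayer` standing in for attention or the feed-forward network, and gamma = 1, beta = 0. All function names are illustrative, not a library API.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm over one token's features (gamma = 1, beta = 0 for brevity)."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def sublayer(x):
    """Hypothetical stand-in for attention or the feed-forward network."""
    return [0.5 * math.tanh(v) for v in x]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def post_ln_block(x):
    # Original arrangement: normalise after the residual add,
    # so LayerNorm also sits on the shortcut path.
    return layer_norm(add(x, sublayer(x)))

def pre_ln_block(x):
    # Pre-LN arrangement: normalise only the sub-layer's input;
    # the residual stream x is carried through by a bare addition.
    return add(x, sublayer(layer_norm(x)))

x = [1.0, -2.0, 3.0, 0.5]
post, pre = x, x
for _ in range(6):
    post = post_ln_block(post)
    pre = pre_ln_block(pre)

print([round(v, 3) for v in post])
print([round(v, 3) for v in pre])
```

In `pre_ln_block`, the input reaches the output through a bare addition, so its Jacobian contains an untouched identity term. In `post_ln_block`, every path, including the shortcut, passes through `layer_norm`. The Post-LN output is always re-centred per token. The Pre-LN residual stream instead keeps the raw scale of x plus bounded updates, which is why Pre-LN stacks typically end with one more LayerNorm.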
