What is gradient clipping, and when is it necessary?
Gradient clipping caps the norm (or per-element value) of the gradients before the optimiser step, preventing any single update from being so large that it destabilises training. It is especially important in recurrent networks and transformers, where gradients can explode across many time steps or layers, and in any network trained with a high learning rate on noisy data.
How to think about it
Gradient clipping is a simple, low-cost guard against catastrophic weight updates caused by exploding gradients, one of the most common instabilities in deep model training.
Why gradients explode
In deep networks, backpropagated gradients are products of many Jacobians, one per layer or timestep. If those Jacobians consistently have norms greater than 1 (common in RNNs, which reuse the same recurrent weights at every timestep), the product grows exponentially with depth or sequence length. A single outlier batch can then push the weights far from a good solution, sometimes unrecoverably.
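The exponential growth can be seen in a toy sketch. Backpropagation through time multiplies the incoming gradient by the transpose of each step's Jacobian. The sketch assumes, purely for illustration, that every step has the same 2x2 Jacobian J = scale · R(theta), a rotation scaled so that both singular values equal `scale`; real RNN Jacobians vary per step and depend on the activations. The helper name `backprop_norm` is hypothetical.

```python
import math

def backprop_norm(scale: float, steps: int, theta: float = 0.3) -> float:
    """Hypothetical helper: norm of a unit gradient after `steps`
    vector-Jacobian products with the fixed Jacobian J = scale * R(theta)."""
    c, s = math.cos(theta), math.sin(theta)
    # Transpose of J, where R(theta) = [[c, -s], [s, c]]
    jt = [[scale * c, scale * s], [-scale * s, scale * c]]
    g = [1.0, 0.0]  # gradient arriving at the last timestep
    for _ in range(steps):
        # Chain rule: one more Jacobian-transpose multiply per timestep
        g = [jt[0][0] * g[0] + jt[0][1] * g[1],
             jt[1][0] * g[0] + jt[1][1] * g[1]]
    return math.hypot(*g)

for scale in (0.9, 1.0, 1.1):
    print(scale, backprop_norm(scale, steps=100))
```

Because a rotation preserves length, each step multiplies the gradient norm by exactly `scale`, so after 100 steps the norms are 0.9**100 ≈ 2.7e-5 (vanishing), 1.0, and 1.1**100 ≈ 1.4e4 (exploding). A per-step factor only 10% above 1 is enough.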
Norm clipping vs value clipping
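Norm clipping (most common): compute the global gradient norm, i.e. the L2 norm of all parameter gradients concatenated into one vector. If it exceeds a threshold C, scale every gradient by C / ‖g‖ so the norm equals C. Directions are preserved; only magnitude is reduced.
Value clipping: clip each gradient element independently to [-C, C]. Simpler, but it changes the gradient's direction whenever some elements are clipped and others are not, because the large components shrink while the small ones stay put.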
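The two rules can be compared in a small from-scratch sketch, assuming gradients are stored as plain Python lists (one flattened list per parameter tensor). The function names are hypothetical and this is not PyTorch's internal implementation (which, for example, adds a small epsilon to the norm); it only shows the arithmetic of each rule and measures the resulting direction with cosine similarity.

```python
import math

def global_norm(grads):
    """L2 norm of all gradient lists treated as one concatenated vector."""
    return math.sqrt(sum(x * x for g in grads for x in g))

def clip_by_global_norm(grads, max_norm):
    """Norm clipping: scale every element by min(1, max_norm / global_norm)."""
    norm = global_norm(grads)
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [[x * scale for x in g] for g in grads]

def clip_by_value(grads, clip_value):
    """Value clipping: clamp each element independently to [-clip_value, clip_value]."""
    return [[max(-clip_value, min(clip_value, x)) for x in g] for g in grads]

def cosine(a, b):
    """Cosine similarity between two gradient sets, flattened."""
    fa = [x for g in a for x in g]
    fb = [x for g in b for x in g]
    dot = sum(x * y for x, y in zip(fa, fb))
    return dot / (math.sqrt(sum(x * x for x in fa)) * math.sqrt(sum(y * y for y in fb)))

# Hypothetical gradients for two parameter tensors, flattened to lists.
grads = [[3.0, 0.1], [4.0]]
norm_clipped = clip_by_global_norm(grads, max_norm=1.0)
value_clipped = clip_by_value(grads, clip_value=1.0)
print(global_norm(norm_clipped))    # ~1.0: magnitude capped at max_norm
print(cosine(grads, norm_clipped))  # ~1.0: direction unchanged
print(cosine(grads, value_clipped)) # < 1: direction distorted
```

Value clipping turns (3.0, 0.1, 4.0) into (1.0, 0.1, 1.0): the two large components are squashed to the same size and the small one now carries relatively more weight, which is exactly the directional distortion that norm clipping avoids.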
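```python
import torch
import torch.nn as nn

model = nn.LSTM(input_size=128, hidden_size=256, num_layers=2, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # placeholder loss; use whatever fits your task

# Placeholder data: 10 batches of (batch=8, seq_len=20, features=128) inputs,
# with targets matching the LSTM's (8, 20, 256) output shape.
dataloader = [(torch.randn(8, 20, 128), torch.randn(8, 20, 256)) for _ in range(10)]

for x, y in dataloader:
    optimizer.zero_grad()
    output, _ = model(x)
    loss = criterion(output, y)
    loss.backward()
    # Clip before the optimiser step: rescale all gradients so that
    # their combined (global) L2 norm is at most 1.0
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```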
Choosing the threshold
A useful diagnostic: log the gradient norm at each step for the first few hundred steps of a stable run, then set max_norm to roughly the 95th percentile of the logged values. Common defaults are 1.0 for both LSTMs and transformers (GPT-3, for example, clipped the global gradient norm at 1.0).
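The percentile heuristic can be sketched as follows. In PyTorch, `torch.nn.utils.clip_grad_norm_` returns the total gradient norm, so calling it with `max_norm=float("inf")` is one way to log the norm without clipping. The sketch below skips the training loop and assumes a hypothetical list of already-logged norms; the helper `suggest_max_norm` is likewise hypothetical, and it uses Python's `statistics.quantiles` for the percentile.

```python
import random
import statistics

def suggest_max_norm(grad_norms, percentile=95):
    """Hypothetical helper: return the given percentile of logged gradient
    norms as a clipping threshold."""
    cut_points = statistics.quantiles(grad_norms, n=100)  # 1st..99th percentiles
    return cut_points[percentile - 1]

# Hypothetical log of 300 per-step gradient norms from a stable warm-up run:
# mostly around 0.5, with an outlier spike of 8.0 every 50 steps.
random.seed(0)
logged = [abs(random.gauss(0.5, 0.1)) for _ in range(300)]
logged[::50] = [8.0] * len(logged[::50])

print(suggest_max_norm(logged))  # lands among the typical norms, far below 8.0
```

Because the spikes make up only 2% of the log, the 95th percentile sits just above the typical norms, so the resulting threshold leaves ordinary steps alone and rescales the spikes.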
Does it slow learning?
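With a threshold chosen this way, the norm stays below it on most steps and clipping is a no-op; it activates only on outlier batches, so average learning speed is essentially unaffected. If the threshold is set too low, however, clipping fires on every step and rescales every gradient, which for plain SGD amounts to a smaller effective learning rate. The computational cost is one extra norm computation over all gradients per step.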
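Whether clipping really is mostly a no-op can be checked from the same logged norms. The sketch below assumes a short, hypothetical list of per-step norms and counts how often global-norm clipping would rescale the gradients, plus the factor it would apply on each step; both helper names are hypothetical.

```python
def clip_rate(grad_norms, max_norm):
    """Fraction of steps on which norm clipping would actually rescale the gradients."""
    clipped = sum(1 for n in grad_norms if n > max_norm)
    return clipped / len(grad_norms)

def effective_scale(grad_norm, max_norm):
    """Multiplier that global-norm clipping applies to every gradient on one step."""
    return min(1.0, max_norm / grad_norm) if grad_norm > 0 else 1.0

# Hypothetical logged norms: typical steps near 0.5, two outlier spikes.
norms = [0.42, 0.55, 0.48, 9.7, 0.51, 0.60, 0.47, 0.53, 12.4, 0.50]

print(clip_rate(norms, max_norm=1.0))  # → 0.2 (only the two spikes are clipped)
print(clip_rate(norms, max_norm=0.3))  # → 1.0 (threshold too low: every step clipped)
print([round(effective_scale(n, 1.0), 3) for n in norms])
```

A clip rate that stays near zero apart from occasional spikes matches the intended behaviour; a rate close to 1 suggests the threshold sits below the typical gradient norm.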