Why do we scale by sqrt(d_k) in scaled dot-product attention?
For large key dimensions, the dot products between query and key vectors grow in magnitude: their variance scales as d_k, so their standard deviation scales as sqrt(d_k). This pushes the softmax into regions with very small gradients. Dividing by sqrt(d_k) keeps the pre-softmax scores at roughly unit variance regardless of dimension, which stabilises training.
How to think about it
Assume the components of the query and key vectors are independent samples from zero-mean, unit-variance distributions. A dot product q · k = sum_{i=1}^{d_k} q_i k_i then has:
- Mean: 0
- Variance: d_k (the sum of d_k independent terms, each with variance 1)
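The two bullet points above can be derived in a few lines. This is a sketch that assumes all components q_i and k_i are mutually independent with zero mean and unit variance, as stated in the setup:

```latex
\begin{aligned}
\mathbb{E}[q_i k_i] &= \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \\
\operatorname{Var}(q_i k_i) &= \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \bigl(\mathbb{E}[q_i k_i]\bigr)^2 = 1 \cdot 1 - 0 = 1, \\
\operatorname{Var}(q \cdot k) &= \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k .
\end{aligned}
```

The last step uses the fact that the variance of a sum of independent terms is the sum of their variances.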
So the standard deviation of the dot product is sqrt(d_k). For d_k = 64, scores have std ~8; for d_k = 512, std ~22.6.
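The sqrt(d_k) growth can be checked empirically. The sketch below assumes standard-normal components (one possible zero-mean, unit-variance distribution); the helper name `dot_product_std`, the sample count, and the seed are illustrative choices, not part of the original explanation.

```python
import math
import random


def dot_product_std(d_k, n_samples=2000, seed=0):
    """Empirical std of q . k for q, k with i.i.d. standard-normal components."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(qi * ki for qi, ki in zip(q, k)))
    mean = sum(dots) / n_samples
    var = sum((x - mean) ** 2 for x in dots) / n_samples
    return math.sqrt(var)


# Print d_k, the measured std, and the predicted sqrt(d_k) side by side.
for d_k in (16, 64, 512):
    print(d_k, round(dot_product_std(d_k), 2), round(math.sqrt(d_k), 2))
```

The measured and predicted columns should agree to within sampling noise, which shrinks as `n_samples` grows.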
When scores are large in magnitude, softmax concentrates nearly all probability mass on the maximum value:
softmax([0, 10, 0, 0]) ≈ [0, 1, 0, 0]
This is effectively a one-hot vector, so the gradient of the softmax with respect to its inputs is near zero everywhere, and learning through the attention weights slows dramatically. Dividing by sqrt(d_k) returns the scores to roughly unit variance before softmax:
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
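The divisor is sqrt(d_k), not d_k, because scaling a variable by a constant scales its variance by the square of that constant: Var(X / c) = Var(X) / c^2. Dividing the scores by sqrt(d_k) therefore divides their variance by d_k, taking it from d_k back to 1. Dividing by d_k instead would shrink the variance to 1/d_k and flatten the softmax towards a uniform distribution.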
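The saturation effect described above can be made concrete by comparing softmax outputs and their derivatives before and after scaling. This is a minimal sketch using the example scores [0, 10, 0, 0] and assuming d_k = 64; the helper names `softmax` and `softmax_jacobian_diag` are hypothetical and written here only for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_jacobian_diag(scores):
    """Diagonal of the softmax Jacobian: d p_i / d s_i = p_i * (1 - p_i)."""
    return [p * (1.0 - p) for p in softmax(scores)]


raw = [0.0, 10.0, 0.0, 0.0]                 # example scores from the text
scaled = [s / math.sqrt(64) for s in raw]   # divided by sqrt(d_k), assuming d_k = 64

print("raw:   ", softmax(raw), softmax_jacobian_diag(raw))
print("scaled:", softmax(scaled), softmax_jacobian_diag(scaled))
```

For the raw scores, every diagonal entry is tiny because each p_i is close to 0 or 1. The off-diagonal entries, -p_i * p_j, are small for the same reason. After scaling, the probabilities are spread out and the derivatives are orders of magnitude larger.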
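import math
import torch
import torch.nn.functional as F

d_k = 64
q = torch.randn(8, 32, d_k)  # queries: zero mean, unit variance per component
k = torch.randn(8, 32, d_k)  # keys: zero mean, unit variance per component
scores = q @ k.transpose(-2, -1)  # unscaled: std ~ sqrt(d_k) = 8
scaled = scores / math.sqrt(d_k)  # restored to ~unit std
weights = F.softmax(scaled, dim=-1)  # attention weights; each row sums to 1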
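For readers who want to see the full formula Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V end to end, here is a dependency-free sketch using plain Python lists instead of tensors. The function name `scaled_dot_product_attention` and the list-of-lists layout are assumptions made for illustration; the sketch omits masking, batching, and multiple heads.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(Q, K, V):
    """Q: n x d_k, K: m x d_k, V: m x d_v (lists of lists). Returns n x d_v."""
    d_k = len(K[0])
    scale = 1.0 / math.sqrt(d_k)
    d_v = len(V[0])
    out = []
    for q in Q:
        # One scaled score per key for this query.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
        weights = softmax(scores)
        # Output row is the attention-weighted average of the value rows.
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(d_v)])
    return out


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(scaled_dot_product_attention(Q, K, V))
```

Each output row is a convex combination of the rows of V, so it always lies between the smallest and largest value in each column.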