Why use multiple attention heads instead of one large attention operation?
Multiple heads let the model attend to different types of relationships (syntactic, semantic, coreference, positional) simultaneously within the same layer. A single head computes one attention distribution per query position, so it produces one weighted mixture of values, and distinct relational patterns get averaged together in that mixture. Splitting into h heads, each projected to a lower dimension, gives h independent subspaces that can each capture their own pattern, at the same total parameter cost.
How to think about it
Multi-head attention runs h attention functions in parallel, each in a lower-dimensional subspace:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O
where head_i = Attention(Q W_Qi, K W_Ki, V W_Vi), and each per-head projection maps from d_model down to d_k = d_v = d_model / h dimensions.
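To make the formula concrete, here is a minimal NumPy sketch for a single, unbatched sequence. It assumes the standard scaled dot-product form Attention(q, k, v) = softmax(q kᵀ / √d_k) v. The helper names (`softmax`, `attention`, `multi_head`, `make_proj`) and the random weight initialisation are hypothetical illustrations, not a library API.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # Scaled dot-product attention (assumed form): softmax(q k^T / sqrt(d_k)) v
    d_k = q.shape[-1]
    return softmax(q @ k.T / np.sqrt(d_k)) @ v

def multi_head(Q, K, V, W_Q, W_K, W_V, W_O):
    # W_Q, W_K, W_V: lists of h per-head matrices, each (d_model, d_k)
    # W_O: (h * d_v, d_model)
    heads = [attention(Q @ wq, K @ wk, V @ wv)
             for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O  # Concat(head_1, ..., head_h) W_O

rng = np.random.default_rng(0)
n, d_model, h = 10, 512, 8
d_k = d_model // h  # 64; d_v = d_k in this sketch

def make_proj():
    # Hypothetical helper: h independent projections from d_model down to d_k.
    return [rng.standard_normal((d_model, d_k)) / np.sqrt(d_model) for _ in range(h)]

W_Q, W_K, W_V = make_proj(), make_proj(), make_proj()
W_O = rng.standard_normal((h * d_k, d_model)) / np.sqrt(h * d_k)

X = rng.standard_normal((n, d_model))  # self-attention: Q = K = V = X
out = multi_head(X, X, X, W_Q, W_K, W_V, W_O)
print(out.shape)  # (10, 512)
```

Each head sees the full d_model-wide input but works in its own 64-dimensional subspace; only the final W_O multiply brings the heads back together.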
Why this works better than one head at d_model:
- Each head learns its own W_Q, W_K, W_V: its own notion of “what is a relevant query” and “what is a relevant key”.
- Head 1 might specialise in subject-verb agreement, head 2 in coreference, head 3 in local adjacency. This differentiation emerges from training, not design.
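- Because d_k = d_model / h, the h heads together use exactly as many projection parameters as one full-dimensional head, and roughly the same computation (small per-head overheads, such as computing h softmaxes instead of one, are the only difference). Multiple heads therefore reorganise the same budget rather than multiplying the cost by h.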
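The cost claim can be checked with a little arithmetic. This pure-Python sketch counts projection weights (ignoring biases) and the multiply-accumulates in the Q Kᵀ score products, assuming both the one-head and the h-head variants include an output projection W_O; the function names are hypothetical.

```python
def projection_params(d_model, h):
    # Q, K and V projections: h heads, each d_model x d_k, with d_k = d_v = d_model // h
    d_k = d_model // h
    qkv = 3 * h * d_model * d_k
    # Output projection W_O: (h * d_v) x d_model
    out = (h * d_k) * d_model
    return qkv + out

def score_macs(n, d_model, h):
    # Multiply-accumulates for all Q K^T products: h heads x (n x n) scores x d_k each
    d_k = d_model // h
    return h * n * n * d_k

print(projection_params(512, 1), projection_params(512, 8))  # 1048576 1048576
print(score_macs(128, 512, 1), score_macs(128, 512, 8))      # 8388608 8388608
```

Both counts are independent of h because h · d_k = d_model; what does grow with h is the number of n × n score matrices that need a softmax.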
The outputs of all heads are concatenated and linearly mixed by W_O ∈ ℝ^{h·d_v × d_model} (with h·d_v = d_model here) to produce the final representation.
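One way to see what “linearly mixed” means: multiplying the concatenation by W_O is the same as giving each head its own row block of W_O and summing the results. This NumPy sketch checks that identity with random stand-ins for the head outputs (the sizes are illustrative assumptions).

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, d_v, d_model = 4, 8, 64, 512
heads = [rng.standard_normal((n, d_v)) for _ in range(h)]  # stand-ins for head_1..head_h
W_O = rng.standard_normal((h * d_v, d_model))

# Route 1: concatenate along the feature axis, then one matrix multiply.
concat_out = np.concatenate(heads, axis=-1) @ W_O

# Route 2: each head multiplies its own (d_v, d_model) row block of W_O; results are summed.
blocks = np.split(W_O, h, axis=0)
summed_out = sum(head @ block for head, block in zip(heads, blocks))

print(np.allclose(concat_out, summed_out))  # True
```

So every output feature is a learned weighted combination of contributions from all heads, which is how W_O lets the separately computed subspaces interact.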
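```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(
    embed_dim=512,     # d_model
    num_heads=8,       # h = 8, so d_k = d_v = 512 / 8 = 64
    batch_first=True,  # inputs are (batch, seq, embed_dim)
)

x = torch.randn(2, 10, 512)  # example batch: 2 sequences of 10 tokens
out, weights = mha(x, x, x)  # self-attention: query = key = value
# out: (2, 10, 512); weights: (2, 10, 10), averaged over heads by default
```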
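To inspect what individual heads attend to, `nn.MultiheadAttention` can return one attention map per head instead of their average. This sketch assumes PyTorch 1.11 or later, where the `average_attn_weights` argument to `forward` is available; the input is random and only illustrates shapes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)  # (batch, seq, embed_dim)

# average_attn_weights=False keeps one attention map per head instead of the mean.
out, per_head = mha(x, x, x, average_attn_weights=False)
print(out.shape)       # torch.Size([2, 10, 512])
print(per_head.shape)  # torch.Size([2, 8, 10, 10])

# Each head's weights are a softmax over keys, so every row sums to 1.
print(torch.allclose(per_head.sum(dim=-1), torch.ones(2, 8, 10)))  # True
```

In a trained model, comparing the eight 10 × 10 maps is a common way to look for the kind of head specialisation described above.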