datarekha
Deep Learning · Medium · Asked at Google, OpenAI, Meta, Anthropic

Why use multiple attention heads instead of one large attention operation?

The short answer

Multiple heads let the model attend to several kinds of relationship (syntactic, semantic, coreference, positional) within the same layer. A single head computes one softmax distribution per query position, so every relationship that position cares about has to be blended into one weighted average of the values. Splitting d_model into h heads, each projected to d_model / h dimensions, gives h independent subspaces, each with its own attention pattern, at roughly the same total parameter and compute cost.

How to think about it

Multi-head attention runs h attention functions in parallel, each in a lower-dimensional subspace:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O

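where Attention(Q, K, V) = softmax(Q K^T / √d_k) V, head_i = Attention(Q W_Qi, K W_Ki, V W_Vi), and each head projects into d_k = d_v = d_model / h dimensions (W_Qi, W_Ki ∈ ℝ^{d_model × d_k}, W_Vi ∈ ℝ^{d_model × d_v}).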
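A minimal sketch of that formula in PyTorch, written with an explicit loop over heads so that each head_i corresponds to one call. The weights are random and untrained, biases, masking and batching are left out, and the helper names (attention, multi_head, rand_proj) are hypothetical. Real implementations fuse the h projections into one d_model × d_model matrix and reshape instead of looping.

```python
import math
import torch

def attention(q, k, v):
    # Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    return torch.softmax(scores, dim=-1) @ v

def multi_head(Q, K, V, W_Q, W_K, W_V, W_O):
    # W_Q, W_K, W_V: lists of h per-head matrices, each (d_model, d_k) or (d_model, d_v)
    # W_O: (h * d_v, d_model)
    heads = [attention(Q @ wq, K @ wk, V @ wv)
             for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return torch.cat(heads, dim=-1) @ W_O  # Concat(head_1, ..., head_h) W_O

torch.manual_seed(0)
d_model, h, n = 512, 8, 10
d_k = d_v = d_model // h  # 64

def rand_proj(rows, cols):
    # Hypothetical helper: random, untrained weights scaled to keep activations O(1)
    return torch.randn(rows, cols) / math.sqrt(rows)

W_Q = [rand_proj(d_model, d_k) for _ in range(h)]
W_K = [rand_proj(d_model, d_k) for _ in range(h)]
W_V = [rand_proj(d_model, d_v) for _ in range(h)]
W_O = rand_proj(h * d_v, d_model)

x = torch.randn(n, d_model)                     # one sequence of n tokens
out = multi_head(x, x, x, W_Q, W_K, W_V, W_O)   # self-attention: Q = K = V = x
print(out.shape)  # torch.Size([10, 512])
```

Each head sees the full d_model-wide input but works in its own 64-dimensional projection, and the output width returns to d_model only after W_O.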

Why this tends to work better than a single head of width d_model:

  • Each head learns its own W_Q, W_K, W_V — its own notion of “what is a relevant query” and “what is a relevant key”.
  • Head 1 might specialise in subject-verb agreement; head 2 in coreference; head 3 in local adjacency. This differentiation emerges from training, not design.
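  • Because d_k = d_model / h, the h heads' projection matrices together have the same shape as one full-width head's (4 · d_model² weights across W_Q, W_K, W_V and W_O, ignoring biases). For a sequence of n tokens, the Q K^T products cost h · n² · d_k = n² · d_model multiply-adds in total, the same as one head. The multiple heads are not free, since each head materialises its own n × n attention matrix and softmax, but they are not h times more expensive either.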
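The cost claim can be checked with PyTorch's nn.MultiheadAttention as a convenient stand-in. The sketch below assumes the module's default configuration (biases on, kdim = vdim = embed_dim) is representative; n_params is a hypothetical helper. The parameter count does not depend on num_heads, while the per-head attention weights (returned when average_attn_weights=False) gain a head dimension of size h.

```python
import torch
import torch.nn as nn

d_model, n = 512, 10

def n_params(module):
    # Hypothetical helper: total number of scalar parameters in a module
    return sum(p.numel() for p in module.parameters())

one_head = nn.MultiheadAttention(d_model, num_heads=1, batch_first=True)
eight_heads = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)

# Same projection shapes either way: heads only slice the d_model columns differently
print(n_params(one_head), n_params(eight_heads))  # 1050624 1050624

x = torch.randn(1, n, d_model)  # (batch, seq_len, d_model)
_, w1 = one_head(x, x, x, average_attn_weights=False)
_, w8 = eight_heads(x, x, x, average_attn_weights=False)
# The extra cost: one n x n attention map per head
print(w1.shape, w8.shape)  # torch.Size([1, 1, 10, 10]) torch.Size([1, 8, 10, 10])
```

The 1,050,624 figure is 4 · 512² projection weights plus 4 · 512 bias terms.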

The outputs of all heads are concatenated and linearly mixed by W_O ∈ ℝ^{hd_v × d_model} to produce the final representation.
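Because W_O acts on the concatenation, it can be split into h blocks of d_v rows, one per head, so the final output is a sum of per-head contributions: Concat(head_1, ..., head_h) W_O = Σ_i head_i W_O^(i), where W_O^(i) (notation introduced here) is the i-th row block. A small numerical check of that identity, using random float64 tensors as stand-ins for real head outputs:

```python
import torch

torch.manual_seed(0)
h, n, d_v, d_model = 8, 10, 64, 512
# Random stand-ins for the h head outputs, each (n, d_v); float64 keeps rounding noise negligible
heads = [torch.randn(n, d_v, dtype=torch.float64) for _ in range(h)]
W_O = torch.randn(h * d_v, d_model, dtype=torch.float64)

concat_then_project = torch.cat(heads, dim=-1) @ W_O
# Rows i*d_v .. (i+1)*d_v - 1 of W_O only ever multiply head i's output
per_head_sum = sum(head @ W_O[i * d_v:(i + 1) * d_v] for i, head in enumerate(heads))

print(torch.allclose(concat_then_project, per_head_sum))  # True
```

This is why W_O is described as mixing the heads: it is the only place where information from different heads is combined within the attention sublayer.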

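import torch
import torch.nn as nn

mha = nn.MultiheadAttention(
    embed_dim=512,    # d_model
    num_heads=8,      # h = 8, so d_k = d_v = 512 / 8 = 64
    batch_first=True,  # inputs and outputs are (batch, seq_len, embed_dim)
)
query = key = value = torch.randn(2, 10, 512)  # self-attention on a random batch
out, weights = mha(query, key, value)
# out: (2, 10, 512); weights: (2, 10, 10), averaged over the 8 heads by default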
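To see how nn.MultiheadAttention's single fused projection lines up with the per-head formula, the sketch below recomputes the module's output by hand from its own weights. It relies on internal layout details of recent PyTorch releases (in_proj_weight stacking W_Q, W_K, W_V row-wise, and head i taking the i-th consecutive slice of d_k columns), so treat that layout as an assumption that the final allclose checks, not as a documented contract. split_heads is a hypothetical helper.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, h = 512, 8
d_k = d_model // h  # 64
# dropout defaults to 0.0, so the forward pass is deterministic
mha = nn.MultiheadAttention(d_model, h, batch_first=True)
x = torch.randn(2, 10, d_model)  # (batch, seq_len, d_model)

def split_heads(t):
    # (batch, seq, d_model) -> (batch, h, seq, d_k): head i owns columns i*d_k .. (i+1)*d_k - 1
    b, s, _ = t.shape
    return t.reshape(b, s, h, d_k).transpose(1, 2)

with torch.no_grad():
    ref, _ = mha(x, x, x)

    # Assumed layout: in_proj_weight is (3*d_model, d_model) = [W_Q; W_K; W_V] stacked row-wise
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)  # (batch, h, seq, seq)
    heads = weights @ v                              # (batch, h, seq, d_k)
    concat = heads.transpose(1, 2).reshape(x.shape)  # Concat(head_1, ..., head_h)
    manual = mha.out_proj(concat)                    # multiply by W_O, plus bias

print(torch.allclose(ref, manual, atol=1e-5))  # True
```

The reshape-and-transpose replaces the explicit loop over heads: all h heads are computed in one batched matmul, which is why the multi-head version runs about as fast as a single head of the same total width.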