What are the concrete reasons transformers outperform RNNs on most sequence tasks?
Transformers win on three axes: parallelism (with no sequential dependency between positions, every position in a training sequence is processed at once on GPUs), path length (any two tokens interact within O(1) layers, not O(n) recurrent steps), and scalability (quality keeps improving with more compute over longer contexts, while RNN quality degrades with sequence length regardless of how much training compute is spent).
How to think about it
The comparison across five dimensions:
| Dimension | RNN / LSTM | Transformer |
|---|---|---|
| Training parallelism | Sequential — step t waits for t-1 | Fully parallel across all positions |
| Long-range dependency path | O(n) multiplicative steps | O(1) attention steps |
| Gradient flow | Vanishes / explodes over distance (LSTM gating mitigates but does not eliminate this) | Additive residual paths; more stable |
| Memory at inference | Fixed-size hidden state | Full KV cache (grows with context) |
| Context length scaling | In practice capped at ~1k tokens | Scales to 128k+ with engineering |
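The inference-memory row can be made concrete with a back-of-the-envelope calculation. The sketch below assumes a hypothetical 32-layer transformer with 8 key/value heads of dimension 128, and a hypothetical 32-layer LSTM with hidden size 4096, both storing fp16 (2-byte) values. The function names are illustrative, not from any library.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int,
                   bytes_per_elem: int = 2, batch: int = 1) -> int:
    """Transformer KV cache: one K and one V tensor per layer, each of shape
    (batch, seq_len, n_kv_heads, head_dim). Grows linearly with seq_len."""
    return 2 * n_layers * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem


def lstm_state_bytes(n_layers: int, hidden: int,
                     bytes_per_elem: int = 2, batch: int = 1) -> int:
    """LSTM recurrent state: a hidden vector h and a cell vector c per layer.
    Independent of how many tokens have already been processed."""
    return 2 * n_layers * batch * hidden * bytes_per_elem


if __name__ == "__main__":
    # Hypothetical model sizes, chosen only for illustration.
    for seq_len in (1_024, 131_072):
        kv = kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, head_dim=128)
        print(f"{seq_len:>7} tokens: KV cache {kv / 2**20:,.0f} MiB")
    lstm = lstm_state_bytes(n_layers=32, hidden=4096)
    print(f"LSTM state at any length: {lstm / 2**10:,.0f} KiB")
```

With these assumed sizes the cache is 128 MiB at 1,024 tokens and 16 GiB at 131,072 tokens, while the LSTM state stays at 512 KiB no matter how long the sequence gets.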
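Parallelism is the biggest practical win. Modern accelerators (GPUs, TPUs) are built for large matrix multiplications, and during training a transformer can batch these across every sequence position at once. An RNN over a length-1024 sequence needs 1024 sequential matrix-vector products, each waiting on the previous hidden state; a transformer layer handles all 1024 positions with a few large matrix-matrix multiplications, which typically runs 10–100× faster on the same hardware. The advantage applies to training and prompt processing; autoregressive generation still produces one token at a time.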
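The dependency structure behind this can be seen in a minimal pure-Python sketch. It simplifies real attention heavily (single head, no learned projections, no mask), and the Elman-style RNN and all helper names are illustrative choices. Each self-attention output row depends only on the inputs, so rows can be computed in any order or all at once as one matrix product; each RNN state depends on the previous one, so its loop cannot be reordered.

```python
import math
import random

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense matrix product; on an accelerator this is one batched kernel."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def softmax(row: list[float]) -> list[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x: Matrix) -> Matrix:
    """Scaled dot-product self-attention for ALL positions in one shot
    (queries = keys = values = x; no projections or mask)."""
    scale = 1 / math.sqrt(len(x[0]))
    scores = matmul(x, [list(col) for col in zip(*x)])  # x @ x^T, shape (T, T)
    weights = [softmax([s * scale for s in row]) for row in scores]
    return matmul(weights, x)


def attention_at(x: Matrix, t: int) -> list[float]:
    """The same computation for position t alone: it needs only the inputs."""
    scale = 1 / math.sqrt(len(x[0]))
    weights = softmax([sum(a * b for a, b in zip(x[t], x_j)) * scale for x_j in x])
    return [sum(w * x_j[k] for w, x_j in zip(weights, x)) for k in range(len(x[0]))]


def rnn(x: Matrix, w: Matrix, u: Matrix) -> Matrix:
    """Elman RNN, h_t = tanh(W h_{t-1} + U x_t): step t needs h_{t-1}."""
    h = [0.0] * len(w)
    states = []
    for x_t in x:  # T strictly sequential steps
        h = [math.tanh(sum(wi * hi for wi, hi in zip(w_row, h))
                       + sum(ui * xi for ui, xi in zip(u_row, x_t)))
             for w_row, u_row in zip(w, u)]
        states.append(h)
    return states


if __name__ == "__main__":
    random.seed(0)
    T, d = 6, 4
    x = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(T)]

    # Attention: computing positions one at a time, even in reverse order,
    # reproduces the all-at-once result, so the T rows are independent work.
    batched = self_attention(x)
    one_by_one = [attention_at(x, t) for t in reversed(range(T))][::-1]
    print(all(math.isclose(a, b, abs_tol=1e-12)
              for ra, rb in zip(batched, one_by_one) for a, b in zip(ra, rb)))

    # RNN: the loop body for step t reads h from step t - 1, so the T steps
    # form a dependency chain that cannot be run side by side.
    w = [[random.uniform(-0.5, 0.5) for _ in range(d)] for _ in range(d)]
    u = [[random.uniform(-0.5, 0.5) for _ in range(d)] for _ in range(d)]
    print(len(rnn(x, w, u)), "sequential steps")
```

Real implementations add a causal mask for language modeling, but the mask only zeroes some attention weights; it does not introduce a dependency between output rows, so training stays parallel.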
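Path length determines how learnable long-range dependencies are. In a transformer, token 1000 can attend to token 1 directly, so the gradient between them passes through at most 2N sub-layers (N layers, each with an attention and a feed-forward sub-layer wrapped in an add-and-norm residual shortcut), no matter how far apart the tokens are. In an RNN, it passes through a product of 999 Jacobians, one per time step, and such a product tends to shrink or grow exponentially with its length.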
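A scalar toy model makes the exponent visible. It multiplies out the chain-rule factors between token 1 and token 1000 for a one-unit tanh RNN, and for a stack of 2N residual sub-layers of the form y = x + c·tanh(x). The recurrent weight 0.9, the depth N = 12, and the sub-layer coefficients are hypothetical values; real Jacobians are matrices, so this only illustrates how the number of factors drives the gradient's scale.

```python
import math
import random


def rnn_gradient(w: float, inputs: list[float]) -> float:
    """d h_T / d h_1 for the scalar RNN h_t = tanh(w * h_{t-1} + x_t).
    Each step contributes one factor w * (1 - h_t**2), so there are
    len(inputs) - 1 factors between the first and last state."""
    h = math.tanh(inputs[0])  # h_1
    grad = 1.0
    for x_t in inputs[1:]:
        h = math.tanh(w * h + x_t)
        grad *= w * (1 - h * h)  # local Jacobian d h_t / d h_{t-1}
    return grad


def residual_gradient(coeffs: list[float], x: float) -> float:
    """d y / d x through residual sub-layers y = x + c * tanh(x).
    Each sub-layer contributes 1 + c * (1 - tanh(x)**2): the identity path
    keeps every factor close to 1 when c is small."""
    grad = 1.0
    for c in coeffs:
        grad *= 1 + c * (1 - math.tanh(x) ** 2)  # derivative at this sub-layer's input
        x = x + c * math.tanh(x)                  # forward to the next sub-layer
    return grad


if __name__ == "__main__":
    random.seed(0)
    tokens = [random.uniform(-1, 1) for _ in range(1000)]
    print(f"RNN, 999 factors:     {rnn_gradient(0.9, tokens):.3e}")

    n_layers = 12  # hypothetical depth N, giving 2N = 24 residual sub-layers
    coeffs = [random.uniform(-0.1, 0.1) for _ in range(2 * n_layers)]
    print(f"Residual, 24 factors: {residual_gradient(coeffs, 0.5):.3f}")
```

Under these assumptions every RNN factor is at most 0.9 in magnitude, so the product falls below 0.9^999 (about 2e-46), while each residual factor lies between 0.9 and 1.1, so the product stays between roughly 0.08 and 10 however far apart the tokens are.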
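Scalability. Transformers follow smooth scaling laws (e.g., Chinchilla): loss falls as a power law in parameters, data, and compute, so doubling the compute yields an improvement that can be forecast in advance. RNNs plateau earlier because their quality degrades with sequence length, so extra context and compute stop paying off before scaling can help.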
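The "predictable improvement" can be illustrated with the parametric loss form fitted in the Chinchilla paper, L(N, D) = E + A/N^α + B/D^β, where N is the parameter count and D the number of training tokens, together with the common approximation that training compute is C ≈ 6ND. The constants below are approximately the published fit and serve only as an illustration; the grid search for the best N at each budget is a simplification chosen for this sketch.

```python
import math

# Approximate constants from the Chinchilla parametric fit; illustrative only.
E, A, B = 1.69, 406.4, 410.7
ALPHA, BETA = 0.34, 0.28


def loss(n_params: float, n_tokens: float) -> float:
    """Irreducible loss E plus power-law terms that shrink with model and data size."""
    return E + A / n_params**ALPHA + B / n_tokens**BETA


def compute_optimal_loss(compute: float, steps: int = 400) -> tuple[float, float]:
    """Best (loss, N) for a FLOP budget, assuming C = 6 * N * D and searching
    N on a log-spaced grid from 1e6 to 1e13 parameters."""
    best = (math.inf, 0.0)
    for i in range(steps + 1):
        n = 10 ** (6 + 7 * i / steps)
        d = compute / (6 * n)  # tokens affordable with this model size
        best = min(best, (loss(n, d), n))
    return best


if __name__ == "__main__":
    budget = 1e21
    for _ in range(6):
        best_loss, n = compute_optimal_loss(budget)
        print(f"C = {budget:.1e} FLOPs -> N ~ {n:.2e}, loss ~ {best_loss:.3f}")
        budget *= 2  # each doubling buys a drop the formula predicts in advance
```

Because the loss is a sum of power laws, each doubling of compute shrinks the reducible part (everything above E) by a roughly constant factor; that regularity is what makes scaling forecastable.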