What causes overfitting in deep neural networks and how do you fight it?
The short answer
Deep nets overfit when they memorise training examples rather than learning generalisable patterns, usually because the model has far more parameters than the signal in the data can constrain. The fix is a layered defence: regularisation, data augmentation, early stopping, and architecture choices.
How to think about it
Overfitting in deep nets has a simple diagnosis: training loss keeps dropping while validation loss plateaus or rises. The model is fitting noise.
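The diagnosis can be automated from logged loss curves. The sketch below is one possible check, not a standard API: the function name `find_overfit_epoch`, the `patience` threshold, and the loss values are all made up for illustration.

```python
def find_overfit_epoch(train_losses, val_losses, patience=3):
    """Return the epoch with the lowest validation loss if validation has not
    improved for `patience` epochs while training loss kept falling.
    Return None if the curves do not (yet) show that pattern."""
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
    epochs_since_best = len(val_losses) - 1 - best_epoch
    train_still_falling = train_losses[-1] < train_losses[best_epoch]
    if epochs_since_best >= patience and train_still_falling:
        return best_epoch
    return None


# Illustrative numbers only: training loss keeps dropping, validation turns upward.
train = [1.00, 0.70, 0.50, 0.35, 0.25, 0.18, 0.12]
val = [1.05, 0.80, 0.65, 0.62, 0.64, 0.67, 0.71]
print(find_overfit_epoch(train, val))  # 3
```

Here the validation loss bottoms out at epoch 3, so every later epoch is a candidate for having fitted noise.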
Root causes
- Too many parameters relative to labelled examples.
- Training too long without a stopping criterion.
- Insufficient data diversity.
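For the first cause, a quick sanity check is to compare the parameter count with the number of labelled examples. The sketch below counts weights and biases for a stack of fully connected layers with widths 512, 256, and 10; the helper name `linear_param_count` and the dataset size of 5,000 are assumptions for illustration, and the resulting ratio is a rough signal rather than a threshold.

```python
def linear_param_count(layer_sizes):
    """Count weights and biases in a stack of fully connected layers."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


n_params = linear_param_count([512, 256, 10])
n_examples = 5_000  # hypothetical number of labelled examples

print(n_params)                         # 133898
print(round(n_params / n_examples, 1))  # 26.8 parameters per example
```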
Layered defences — use several, not one
| Technique | What it does | When to reach for it |
|---|---|---|
| Dropout | Randomly zeroes activations during training, forcing redundant representations | Dense/recurrent layers; less useful in BatchNorm-heavy CNNs |
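| Weight decay (L2) | Shrinks weights toward zero at every update; equivalent to an L2 penalty in the loss under plain SGD, applied directly to the weights in AdamW | Almost always; 1e-4 is a reasonable starting value to tune |
| Data augmentation | Expands the effective dataset with label-preserving transforms (flips, crops, colour jitter) or sample mixing (Mixup) | Vision and audio tasks |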
| Early stopping | Halts training when val loss stops improving | Universal; save the checkpoint at best val loss |
| Reduce model size | Fewer layers or narrower layers | When a simpler architecture matches the task's complexity |
| BatchNorm | Normalises activations; acts as a mild regulariser | CNNs and deep MLPs |
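```python
import torch
import torch.nn as nn

# Small MLP classifier with dropout between the two linear layers.
model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(p=0.3),  # zero 30% of activations, during training only
    nn.Linear(256, 10),
)

# AdamW applies weight decay directly to the weights (decoupled from the gradient).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
```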
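Of the augmentation methods in the table, Mixup is the least obvious, so here is a minimal sketch of the idea in plain Python: blend two examples and their one-hot labels with a weight drawn from a Beta(alpha, alpha) distribution. The function name `mixup`, the default `alpha=0.2`, and the toy inputs are assumptions for illustration; a real pipeline would do this on tensors, one batch at a time.

```python
import random


def mixup(x1, y1, x2, y2, alpha=0.2, rng=random):
    """Blend two (input, one-hot label) pairs with weight lam ~ Beta(alpha, alpha)."""
    lam = rng.betavariate(alpha, alpha)
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam


# Toy two-feature inputs with two-class one-hot labels.
rng = random.Random(0)  # seeded so the run is repeatable
x, y, lam = mixup([0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], rng=rng)
print(lam, x, y)
```

The mixed label stays a valid probability distribution (its entries sum to 1), so it can be used with a loss that accepts soft targets.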
Early stopping is often the single highest-leverage technique: once the validation curve flattens, every extra epoch buys training-set performance at the cost of generalisation.
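A minimal sketch of patience-based early stopping, kept framework-free so it runs on its own. The class name `EarlyStopping`, its `patience` and `min_delta` parameters, and the validation losses are illustrative assumptions; in a real loop you would save a model checkpoint wherever the best loss is updated.

```python
class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0  # the caller would save a checkpoint here
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
val_losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.72]  # made-up validation curve
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        break

print(epoch, stopper.best_epoch, stopper.best_loss)  # 4 2 0.65
```

Training stops at epoch 4, and the checkpoint to restore is the one from epoch 2, where validation loss was lowest.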