How do you handle severe class imbalance when training a deep learning model?
Class imbalance lets the model minimize its loss by favoring the majority class and ignoring the minority. The main levers are loss reweighting, oversampling or undersampling, focal loss, and choosing the right evaluation metric: accuracy is misleading here, so use F1, precision-recall AUC, or MCC.
How to think about it
A fraud detection dataset with a 0.1% positive rate can produce a model that predicts “not fraud” for every input and still achieves 99.9% accuracy. Such a model has learned nothing about fraud: its recall on the positive class is zero.
Technique 1 — Loss reweighting
Assign higher loss weight to the minority class so that misclassifying a rare positive hurts more.
import torch
import torch.nn as nn
# 100 negatives for every 1 positive → weight positives 100x
pos_weight = torch.tensor([100.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
For multiclass, pass weight to nn.CrossEntropyLoss.
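As a rough sketch of the multiclass case, the snippet below derives inverse-frequency class weights and hands them to nn.CrossEntropyLoss. The label tensor, the three-class setup, and the weighting formula n_samples / (n_classes * count) are illustrative assumptions, not something the text above prescribes; other weighting schemes are equally valid.

```python
import torch
import torch.nn as nn

# Hypothetical integer labels for a 3-class problem (class 2 is rare).
labels = torch.tensor([0] * 90 + [1] * 9 + [2] * 1)

# Inverse-frequency weights: n_samples / (n_classes * count_c).
# Assumes every class appears at least once; a zero count would give inf.
counts = torch.bincount(labels, minlength=3).float()
class_weights = labels.numel() / (counts.numel() * counts)

criterion = nn.CrossEntropyLoss(weight=class_weights)

# Dummy all-zero logits, only to show the expected shapes: (batch, n_classes).
logits = torch.zeros(labels.numel(), 3)
loss = criterion(logits, labels)
```

With the default mean reduction, nn.CrossEntropyLoss divides by the sum of the per-sample weights, so the loss scale stays comparable to the unweighted case.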
Technique 2 — Oversampling / undersampling
- Oversample the minority class (repeat samples or use SMOTE-style synthesis).
- Undersample the majority class to balance the batch.
- PyTorch’s WeightedRandomSampler handles this at the DataLoader level cleanly.
from torch.utils.data import WeightedRandomSampler
# Give each sample a weight inversely proportional to its class frequency
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(dataset))
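The snippet above leaves sample_weights and dataset undefined. One possible way to build them is sketched here, assuming a toy TensorDataset with 990 negatives and 10 positives; the data, the batch size, and the variable names are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical toy dataset: 990 negatives, 10 positives.
features = torch.randn(1000, 4)
targets = torch.cat([torch.zeros(990), torch.ones(10)]).long()
dataset = TensorDataset(features, targets)

# Per-sample weight = 1 / (number of samples in that sample's class).
class_counts = torch.bincount(targets)
sample_weights = 1.0 / class_counts[targets].float()

# replacement=True lets the rare samples be drawn many times per epoch.
sampler = WeightedRandomSampler(
    weights=sample_weights, num_samples=len(dataset), replacement=True
)

# A sampler cannot be combined with shuffle=True, so shuffle is left unset.
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Count how many positives one pass over the loader yields.
positives_seen = sum(int(y.sum()) for _, y in loader)
```

With these weights each class carries the same total sampling probability, so roughly half of the drawn samples should be positives, each of the 10 originals repeated many times.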
Technique 3 — Focal loss
Focal loss (Lin et al., RetinaNet) down-weights the loss on well-classified easy examples, forcing the model to focus on hard, often minority, samples.
FL(p_t) = -(1 - p_t)^gamma * log(p_t)
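Here p_t is the model's predicted probability for the true class. gamma=2 is the value Lin et al. found to work best and is a reasonable starting point for severe imbalance; gamma=0 recovers ordinary cross-entropy.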
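A minimal sketch of the binary form of this loss in PyTorch follows. The function name binary_focal_loss and the optional alpha class-balancing term are assumptions added for illustration; the formula above has no alpha factor, and leaving alpha as None reproduces it exactly.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, gamma=2.0, alpha=None):
    """Mean binary focal loss computed from raw logits.

    targets are floats in {0, 1}. If alpha is given, positives are
    weighted by alpha and negatives by (1 - alpha).
    """
    # Per-element cross-entropy, which equals -log(p_t).
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t is the probability the model assigns to the true class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # The modulating factor (1 - p_t)^gamma shrinks the loss on easy examples.
    loss = (1 - p_t) ** gamma * ce
    if alpha is not None:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()

# Example call on three made-up predictions.
example_logits = torch.tensor([3.0, -2.0, 0.5])
example_targets = torch.tensor([1.0, 0.0, 0.0])
example_loss = binary_focal_loss(example_logits, example_targets)
```

Computing the cross-entropy term from logits rather than from log(sigmoid(x)) directly avoids numerical trouble when the model is very confident.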
Technique 4 — Evaluation discipline
Track precision, recall, F1-score, or area under the precision-recall curve. Never use accuracy as a primary metric for imbalanced problems.
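To make the gap between accuracy and the other metrics concrete, here is a small dependency-free sketch that computes them from confusion-matrix counts. The helper name binary_metrics and the convention of returning 0.0 when a metric is undefined are assumptions made for this example; the counts mirror the 0.1% fraud scenario with a classifier that always predicts the negative class.

```python
import math

def binary_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall, F1, and MCC from confusion counts.

    Undefined ratios (zero denominator) are reported as 0.0 by convention.
    """
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1, "mcc": mcc}

# 100,000 transactions, 100 of them fraud, classifier predicts "not fraud" always.
always_negative = binary_metrics(tp=0, fp=0, fn=100, tn=99_900)
```

For this degenerate classifier accuracy comes out at 0.999 while recall, F1, and MCC are all 0.0, which is why the latter are the ones worth tracking.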
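In practice, combine reweighting (always cheap) with oversampling or focal loss, easing off the class weights when you also resample, since both correct for the same imbalance. Validate on stratified splits that are never resampled, so the validation set reflects the real-world class distribution.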