datarekha
Deep Learning · Medium · Asked at Stripe, Google, Amazon, Palantir

How do you handle severe class imbalance when training a deep learning model?

The short answer

Under severe class imbalance, a model can minimize its loss by predicting the majority class and ignoring the minority. The main levers are loss reweighting, oversampling or undersampling, focal loss, and choosing the right evaluation metric. Accuracy is misleading here; use F1, precision-recall AUC, or MCC (Matthews correlation coefficient) instead.

How to think about it

On a fraud detection dataset with a 0.1% positive rate, a model that predicts “not fraud” for every input achieves 99.9% accuracy while catching no fraud at all. That is not a useful model; it is a constant predictor that has learned nothing about fraud.
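
To make the arithmetic concrete, here is a minimal sketch that scores an always-negative predictor on a hypothetical label set with a 0.1% positive rate. The dataset size and variable names are illustrative assumptions, not taken from any real system.

```python
# Hypothetical dataset: 100,000 transactions, 0.1% of them fraudulent.
n_total = 100_000
n_positive = 100
labels = [1] * n_positive + [0] * (n_total - n_positive)

# A "model" that ignores its input and always predicts the majority class.
predictions = [0] * n_total

correct = sum(p == y for p, y in zip(predictions, labels))
true_positives = sum(p == 1 and y == 1 for p, y in zip(predictions, labels))

accuracy = correct / n_total          # fraction of predictions that match the label
recall = true_positives / n_positive  # fraction of fraud cases that were caught

print(f"accuracy={accuracy:.3f} recall={recall:.3f}")
# prints: accuracy=0.999 recall=0.000
```

Accuracy looks excellent, yet recall shows that every fraud case was missed.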

Technique 1 — Loss reweighting

Assign higher loss weight to the minority class so that misclassifying a rare positive hurts more.

import torch
import torch.nn as nn

# 100 negatives for every 1 positive → weight positives 100x
pos_weight = torch.tensor([100.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
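
Rather than hard-coding the ratio, pos_weight can be derived from the training labels. The sketch below assumes a hypothetical 1-D float tensor named train_labels that holds every training label; the counts are made up for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical binary training labels: 1000 negatives and 10 positives.
train_labels = torch.cat([torch.zeros(1000), torch.ones(10)])

n_pos = train_labels.sum()
n_neg = train_labels.numel() - n_pos

# Ratio of negatives to positives; here 1000 / 10 = 100.
pos_weight = (n_neg / n_pos).reshape(1)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# The loss expects raw logits and float targets of the same shape.
logits = torch.zeros(4, 1)
targets = torch.tensor([[1.0], [0.0], [0.0], [0.0]])
loss = criterion(logits, targets)
```

Compute the ratio on the training split only, so the validation labels do not leak into the loss.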

For multiclass problems, pass a per-class weight tensor to nn.CrossEntropyLoss through its weight argument.
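
A minimal multiclass sketch, assuming a hypothetical 3-class problem with made-up class counts. Inverse-frequency weighting is one common choice, not the only one.

```python
import torch
import torch.nn as nn

# Hypothetical per-class training counts for a 3-class problem (assumed values).
class_counts = torch.tensor([9000.0, 900.0, 100.0])

# Inverse-frequency weights, scaled so a perfectly balanced dataset gives 1.0 each.
class_weights = class_counts.sum() / (len(class_counts) * class_counts)

criterion = nn.CrossEntropyLoss(weight=class_weights)

# Dummy batch: 4 samples, 3 classes (raw logits, not probabilities).
logits = torch.randn(4, 3)
targets = torch.tensor([0, 1, 2, 2])
loss = criterion(logits, targets)
```

With the default mean reduction, the weighted loss is normalized by the sum of the weights of the targets in the batch, so the loss scale stays comparable across batches.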

Technique 2 — Oversampling / undersampling

  • Oversample the minority class (repeat samples or use SMOTE-style synthesis).
  • Undersample the majority class to balance the batch.
  • PyTorch’s WeightedRandomSampler handles this cleanly at the DataLoader level.

from torch.utils.data import WeightedRandomSampler

# sample_weights (assumed precomputed): one weight per sample, inversely
# proportional to the frequency of that sample's class in `dataset`.
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(dataset), replacement=True)
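
For a fuller picture, the sketch below builds the per-sample weights and plugs the sampler into a DataLoader. The toy dataset, its sizes, and the variable names are assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical toy dataset: 990 negatives and 10 positives, 4 features each.
features = torch.randn(1000, 4)
labels = torch.cat([torch.zeros(990, dtype=torch.long),
                    torch.ones(10, dtype=torch.long)])
dataset = TensorDataset(features, labels)

# Weight each sample by the inverse frequency of its class.
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].float()

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(dataset),
    replacement=True,  # minority samples are drawn repeatedly
)

# Pass the sampler instead of shuffle=True; the two options are mutually exclusive.
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

drawn = torch.cat([batch_labels for _, batch_labels in loader])
positive_fraction = drawn.float().mean().item()
```

Because each class carries the same total weight, roughly half of the drawn samples should be positives, even though positives make up only 1% of the dataset. The same 10 positives are repeated many times, so watch for overfitting to them.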

Technique 3 — Focal loss

Focal loss (Lin et al., RetinaNet) down-weights the loss on well-classified easy examples, forcing the model to focus on hard, often minority, samples.

FL(p_t) = -(1 - p_t)^gamma * log(p_t)

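Here p_t is the model's predicted probability for the true class. Setting gamma=0 recovers standard cross-entropy; gamma=2 is the value the paper found to work best and is a reasonable default for severe imbalance.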
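
The formula can be sketched for the binary case as follows. The function name binary_focal_loss is hypothetical, and this version omits the alpha-balancing term that the paper also describes; treat it as an illustration, not a reference implementation.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits, targets, gamma=2.0):
    """Mean focal loss for binary classification, computed from raw logits.

    targets must be a float tensor of 0s and 1s with the same shape as logits.
    """
    # Per-element cross-entropy, which equals -log(p_t).
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # p_t is the probability the model assigns to the true class.
    p_t = torch.exp(-ce)
    # The modulating factor (1 - p_t)^gamma shrinks the loss of easy examples.
    return ((1.0 - p_t) ** gamma * ce).mean()


# Two confidently correct examples and one uncertain positive.
logits = torch.tensor([4.0, -4.0, 0.0])
targets = torch.tensor([1.0, 0.0, 1.0])

focal = binary_focal_loss(logits, targets)
plain = F.binary_cross_entropy_with_logits(logits, targets)
```

On this batch the focal loss is smaller than plain cross-entropy, and almost all of what remains comes from the uncertain third example.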

Technique 4 — Evaluation discipline

Track precision, recall, F1-score, or area under the precision-recall curve. Never use accuracy as a primary metric for imbalanced problems.
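
As a dependency-free sketch of what these metrics compute, the functions below derive precision, recall, and F1 from hard predictions, and average precision (a common step-wise estimate of the area under the precision-recall curve) from ranked scores. The function names and the sample data are hypothetical, and the average-precision code assumes no tied scores.

```python
def precision_recall_f1(labels, predictions):
    """Precision, recall and F1 for binary labels and hard 0/1 predictions."""
    tp = sum(1 for y, p in zip(labels, predictions) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, predictions) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, predictions) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def average_precision(labels, scores):
    """Mean of the precision values measured at the rank of each positive."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits, precisions = 0, []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


# Hypothetical scores for 8 samples, 2 of which are positive.
labels = [0, 0, 1, 0, 0, 0, 1, 0]
scores = [0.10, 0.40, 0.90, 0.20, 0.70, 0.05, 0.60, 0.30]
predictions = [1 if s >= 0.5 else 0 for s in scores]

precision, recall, f1 = precision_recall_f1(labels, predictions)
ap = average_precision(labels, scores)
```

Precision, recall, and F1 depend on the 0.5 threshold chosen here, while average precision summarizes the ranking across all thresholds.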

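In practice, combine reweighting (always cheap) with oversampling or focal loss. When stacking corrections, tune the strength of each: full inverse-frequency weights in both the loss and the sampler compensate for the imbalance twice. Validate using stratified splits so the validation set reflects the real-world class distribution.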
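
A stratified split can be sketched with the standard library alone. The helper name stratified_split and the label counts are hypothetical.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_fraction=0.2, seed=0):
    """Return (train_indices, val_indices), splitting each class in the same proportion."""
    rng = random.Random(seed)
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)

    train_indices, val_indices = [], []
    for class_indices in indices_by_class.values():
        rng.shuffle(class_indices)
        n_val = round(len(class_indices) * val_fraction)
        val_indices.extend(class_indices[:n_val])
        train_indices.extend(class_indices[n_val:])
    return train_indices, val_indices


# Hypothetical labels: 990 negatives and 10 positives.
labels = [0] * 990 + [1] * 10
train_indices, val_indices = stratified_split(labels, val_fraction=0.2)
val_positives = sum(labels[i] for i in val_indices)
```

Both splits keep the original 1% positive rate. Apply oversampling or reweighting to the training indices only, and leave the validation split untouched. If scikit-learn is available, its train_test_split function accepts a stratify argument that serves the same purpose.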
