datarekha
Machine Learning · Medium · Asked at Meta · Asked at LinkedIn · Asked at Stripe

What is stratified k-fold cross-validation and when is it necessary?

The short answer

Stratified k-fold ensures each fold has the same class-label proportions as the full dataset, as nearly as integer counts allow. It is necessary for imbalanced classification because standard k-fold ignores the labels and can produce folds where a minority class is scarce or entirely absent, making per-fold metrics undefined or severely misleading.

How to think about it

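Standard k-fold partitions examples into k folds without looking at the labels (scikit-learn's KFold does not even shuffle by default, so each fold is a contiguous block of rows). On a balanced dataset this is fine. On an imbalanced dataset, say 95% negative and 5% positive, a 200-example fold should hold about 10 positives, but with random assignment one fold might get 8 and another 13, and if the rows are sorted by label an unshuffled fold can get none at all. With no positives, recall and ROC AUC are undefined for that fold; with only a few, recall and F1 swing widely from fold to fold.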
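To see the failure concretely, here is a minimal sketch on a hypothetical 1,000-row dataset with 50 positives stored in label order (an assumption chosen to make the worst case visible). It counts the positives in each test fold produced by scikit-learn's KFold, first without shuffling and then with it.

```python
import numpy as np
from sklearn.model_selection import KFold

# Hypothetical labels: 1,000 rows, 5% positive, sorted by label (worst case)
y = np.array([0] * 950 + [1] * 50)
X = np.zeros((len(y), 1))  # features do not affect how KFold splits

# Unshuffled KFold: folds are contiguous blocks of 200 rows
plain_counts = [int(y[test].sum()) for _, test in KFold(n_splits=5).split(X)]
print(plain_counts)  # [0, 0, 0, 0, 50]

# Shuffled KFold: positives land at random, so counts vary around the expected 10
shuffled_kf = KFold(n_splits=5, shuffle=True, random_state=0)
shuffled_counts = [int(y[test].sum()) for _, test in shuffled_kf.split(X)]
print(shuffled_counts)
```

Four of the five unshuffled folds contain no positives, so recall on those folds cannot be computed; shuffling helps but still leaves the per-fold count to chance.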

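Stratified k-fold fixes this by splitting each class separately: the positives are divided into k roughly equal groups, the negatives likewise, and each fold receives one group from every class. If the dataset is 5% positive, every fold is also approximately 5% positive.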
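Running the same hypothetical labels through StratifiedKFold shows the effect. Because 50 positives and 950 negatives both divide evenly by 5, each test fold should receive exactly 10 positives and 190 negatives. This is a sketch on idealized data; when counts do not divide evenly, scikit-learn's per-class fold counts differ by at most one.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Same hypothetical labels as before: 950 negatives followed by 50 positives
y = np.array([0] * 950 + [1] * 50)
X = np.zeros((len(y), 1))

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
strat_counts = [int(y[test].sum()) for _, test in skf.split(X, y)]
fold_sizes = [len(test) for _, test in skf.split(X, y)]

print(strat_counts)  # [10, 10, 10, 10, 10]
print(fold_sizes)    # [200, 200, 200, 200, 200]
```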

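```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Illustrative imbalanced data (~95% negative, ~5% positive); substitute your own X, y
X, y = make_classification(n_samples=1000, weights=[0.95], random_state=42)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

scores = cross_val_score(
    LogisticRegression(max_iter=1000),
    X, y,
    cv=skf,            # pass the splitter directly
    scoring="f1",
)
print(f"F1: {scores.mean():.3f} +/- {scores.std():.3f}")
```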

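When cv is an integer (or left as None), cross_val_score picks the splitter for you: StratifiedKFold if the estimator is a classifier and y is binary or multiclass (string labels work too), plain KFold otherwise. That default StratifiedKFold does not shuffle, so pass an explicit splitter like the one above when you want shuffled, reproducible folds. StratifiedKFold itself accepts only binary or multiclass targets: for regression it raises an error (a common workaround is to stratify on a binned copy of y), and multi-label targets are not supported.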
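The default-splitter rule can be checked with sklearn.model_selection.check_cv, the helper that resolves an integer cv into a splitter. A minimal sketch with two hypothetical targets, one categorical and one continuous, which also shows StratifiedKFold rejecting the continuous one:

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

y_class = np.array(["no"] * 95 + ["yes"] * 5)  # hypothetical string labels
y_reg = np.linspace(0.0, 1.0, 100)             # hypothetical continuous target

cv_clf = check_cv(5, y_class, classifier=True)
cv_reg = check_cv(5, y_reg, classifier=False)
print(type(cv_clf).__name__, cv_clf.shuffle)  # StratifiedKFold False
print(type(cv_reg).__name__)                  # KFold

# StratifiedKFold refuses a continuous target; split() is lazy, so consume it
try:
    list(StratifiedKFold(n_splits=5).split(np.zeros((100, 1)), y_reg))
    rejected = False
except ValueError as err:
    rejected = True
    print("rejected:", err)
```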

When else to stratify:

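  • Multi-class problems: pass the multi-class label as y so every class keeps its frequency in each fold. A class with fewer than k examples cannot appear in every fold, and scikit-learn warns when that happens.
  • Combined class + group stratification: when rows come in groups (for example, several records per patient) that must never be split between training and validation, use sklearn.model_selection.StratifiedGroupKFold, which keeps each group in a single fold while approximately preserving class proportions.
  • Train/test splits: train_test_split(..., stratify=y) keeps the class ratio in both halves for the same reason.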

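Stratification vs. oversampling: stratification only controls how folds are formed; it does not change the class balance the model trains on. Resampling methods such as SMOTE must be fit on the training portion of each fold only, so the validation fold keeps the real class distribution and no synthetic points leak into it. Class-weighted losses are set on the estimator instead and leave the data untouched.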
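In code, the distinction means resampling happens inside the CV loop. A minimal sketch, assuming synthetic data from make_classification and using plain random oversampling with NumPy as a stand-in for SMOTE (SMOTE itself lives in the separate imbalanced-learn package and is not used here):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

# Synthetic imbalanced data (~5% positive), for illustration only
X, y = make_classification(n_samples=1000, weights=[0.95], random_state=42)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
rng = np.random.default_rng(0)

scores = []
for train_idx, test_idx in skf.split(X, y):
    X_tr, y_tr = X[train_idx], y[train_idx]

    # Random oversampling (stand-in for SMOTE): duplicate minority rows
    # in the TRAINING fold only until both classes are the same size
    pos = np.flatnonzero(y_tr == 1)
    neg = np.flatnonzero(y_tr == 0)
    extra = rng.choice(pos, size=len(neg) - len(pos), replace=True)
    X_bal = np.vstack([X_tr, X_tr[extra]])
    y_bal = np.concatenate([y_tr, y_tr[extra]])

    model = LogisticRegression(max_iter=1000).fit(X_bal, y_bal)

    # The test fold is untouched, so it keeps the real ~5% positive rate
    scores.append(f1_score(y[test_idx], model.predict(X[test_idx])))

print(f"F1 with in-fold oversampling: {np.mean(scores):.3f}")
```

The weighting alternative is a one-line change: LogisticRegression(class_weight="balanced") rescales each class's contribution to the loss and needs no resampling step at all.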
