datarekha

KL Divergence

How far is one probability distribution from another? KL divergence answers that — it is the engine behind cross-entropy loss, VAEs, distillation, and RLHF.

8 min read · Advanced Math for ML · Lesson 29 of 30

What you'll learn

  • Why cross-entropy loss is secretly minimizing a KL divergence to the true label distribution
  • How to compute D(P||Q) by hand and what the bits mean intuitively
  • Why KL is asymmetric and how that asymmetry changes model behavior in VAEs and distillation

Before you start

A language model was trained to predict the next token. On one test sentence it assigned probability 0.9 to the correct token and spread the remaining 0.1 across 49 999 other tokens. A second model assigned 0.5 to the correct token and spread 0.5 equally across the rest. Which model is further from the true distribution — where the correct token has probability 1.0 and every other token has probability 0.0?

Cross-entropy loss answers that question numerically. But to understand why it works, you need the quantity hiding inside it: KL divergence.

The core formula

Let P be the true distribution (what the data actually is) and Q be the model distribution (what your model guesses). KL divergence from Q to P — written D(P||Q) and read “KL of P from Q” — is:

D(P||Q) = sum over all x of  P(x) * log2( P(x) / Q(x) )

This lesson uses log base 2, so the unit is bits. (Some texts use the natural log; the unit then is nats. The formula and all properties are identical — only the scale changes.)
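Here is a minimal Python sketch of the formula. It assumes discrete distributions given as equal-length lists of probabilities. The helper name `kl_divergence` and its `base` argument are our own and do not come from any library. The example also shows that switching from bits to nats only rescales the result by a factor of ln 2:

```python
import math

def kl_divergence(p, q, base=2.0):
    """D(P||Q) for two discrete distributions given as equal-length sequences.

    Terms with P(x) = 0 contribute nothing (convention: 0 * log 0 = 0).
    If Q(x) = 0 where P(x) > 0, the divergence is infinite.
    """
    total = 0.0
    for px, qx in zip(p, q):
        if px == 0:
            continue  # 0 * log(0 / q) is taken to be 0
        if qx == 0:
            return math.inf  # P puts mass where Q has none
        total += px * math.log(px / qx, base)
    return total

P = [0.5, 0.5]  # fair coin
Q = [0.9, 0.1]  # biased model
bits = kl_divergence(P, Q, base=2)
nats = kl_divergence(P, Q, base=math.e)
print(f"{bits:.4f} bits, {nats:.4f} nats")
print(math.isclose(nats, bits * math.log(2)))  # nats = bits * ln 2
```

For the coin pair used later in this lesson, this gives about 0.737 bits, or about 0.511 nats. The two explicit checks encode the usual conventions: outcomes with P(x) = 0 are skipped, and Q(x) = 0 where P(x) > 0 makes the divergence infinite.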

What one term means

Pick a single outcome x. The ratio P(x) / Q(x) compares the model's probability with the truth. If your model Q says outcome x has probability 0.1 but the truth P says 0.5, the ratio is 5. The model underestimated this outcome by a factor of 5, so it is more surprised than it should be whenever x shows up. Surprise is measured as log2(1/Q(x)), and log2(1/Q(x)) - log2(1/P(x)) = log2(P(x)/Q(x)), so log2(5) ≈ 2.32 is the number of extra bits of surprise for that outcome. Multiplying by P(x) weights the term by how often the outcome actually occurs. The sum then collects these weighted extra surprises across every outcome.

Plain-English summary: D(P||Q) is the average number of extra bits per observation you pay when you encode data drawn from P with a code optimized for Q instead of one optimized for P.

Three essential properties

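1. Always non-negative: D(P||Q) >= 0 for any two distributions (Gibbs' inequality). This follows from Jensen’s inequality applied to the concave log function. For a direct proof, work in nats and apply ln(t) <= t - 1 with t = Q(x)/P(x) to each term of -D(P||Q). The sum is then at most the sum of Q(x) over outcomes with P(x) > 0, minus 1, and that is at most 0. Changing the log base only rescales by a positive constant, so the bound also holds in bits.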

2. Zero if and only if equal: D(P||Q) = 0 exactly when P(x) = Q(x) for every x. Once the distributions match perfectly, there is no extra surprise.

3. Asymmetric: in general, D(P||Q) is not equal to D(Q||P). A distance must be symmetric (Euclidean distance is an example), and it must also satisfy the triangle inequality, which KL fails too. That is why KL is called a divergence rather than a distance.
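The three properties can be spot-checked numerically. The sketch below is plain Python. The helpers `kl_bits` and `random_distribution` are illustrative names, not library functions. It samples random four-outcome distributions with a fixed seed. Sampling cannot prove the "only if" half of property 2; it only checks the claims on examples:

```python
import math
import random

def kl_bits(p, q):
    """D(P||Q) in bits; assumes q[i] > 0 wherever p[i] > 0."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def random_distribution(k, rng):
    """Draw k positive weights and normalize them to sum to 1."""
    w = [rng.random() + 1e-9 for _ in range(k)]
    s = sum(w)
    return [x / s for x in w]

rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(1000):
    p = random_distribution(4, rng)
    q = random_distribution(4, rng)
    # Property 1: non-negative (tiny tolerance for floating-point rounding)
    assert kl_bits(p, q) >= -1e-12
    # Property 2: zero when the two distributions are identical
    assert abs(kl_bits(p, p)) < 1e-12

# Property 3: the two directions generally differ
p, q = [0.5, 0.5], [0.9, 0.1]
print(kl_bits(p, q), kl_bits(q, p))
```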

Worked example: two coins

Let P = [0.5, 0.5] (a fair coin — heads and tails equally likely) and Q = [0.9, 0.1] (a biased model that thinks heads is nine times more likely than tails).

Computing D(P||Q):

D(P||Q) = 0.5 * log2(0.5 / 0.9)  +  0.5 * log2(0.5 / 0.1)
        = 0.5 * log2(0.5556)       +  0.5 * log2(5.0)
        = 0.5 * (-0.8480)          +  0.5 * (2.3219)
        = -0.4240 + 1.1610
        = 0.737 bits

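When the fair coin lands tails (probability 0.5 under P), your biased model assigns it only 0.1. That is a ratio of 5, which costs log2(5) ≈ 2.32 extra bits each time tails appears, or 1.161 bits after weighting by P(tails) = 0.5. That tail term dominates. On heads, the model assigns more probability than the truth does (0.9 vs 0.5), so that term is negative (−0.424 bits) and offsets part of the cost. Individual terms can be negative, but the total cannot: the net penalty is 0.737 bits per observation.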

Computing D(Q||P):

D(Q||P) = 0.9 * log2(0.9 / 0.5)  +  0.1 * log2(0.1 / 0.5)
        = 0.9 * log2(1.8)          +  0.1 * log2(0.2)
        = 0.9 * (0.8480)           +  0.1 * (-2.3219)
        = 0.7632 + (-0.2322)
        = 0.531 bits

The two directions give different numbers: 0.737 bits versus 0.531 bits. Computing both in code confirms the hand calculations.
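The page's original playground code is not reproduced here. A minimal pure-Python sketch that prints the output below could look like this. `kl_bits` is an illustrative helper name, and terms with P(x) = 0 are skipped under the usual 0 * log 0 = 0 convention:

```python
import math

def kl_bits(p, q):
    """D(P||Q) in bits; terms with P(x) = 0 are skipped (0 * log 0 = 0)."""
    return sum(px * math.log2(px / qx) for px, qx in zip(p, q) if px > 0)

P = [0.5, 0.5]  # fair coin (true distribution)
Q = [0.9, 0.1]  # biased model

d_pq = kl_bits(P, Q)
d_qp = kl_bits(Q, P)

print(f"D(P||Q) = {d_pq:.4f} bits  (P=fair, Q=biased)")
print(f"D(Q||P) = {d_qp:.4f} bits  (Q=biased, P=fair)")
print("Symmetric?", math.isclose(d_pq, d_qp))
```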

Output:

D(P||Q) = 0.7370 bits  (P=fair, Q=biased)
D(Q||P) = 0.5310 bits  (Q=biased, P=fair)
Symmetric? False

The distribution diagram

The figure below shows P and Q side by side, with the per-outcome surprise contributions shaded for the D(P||Q) direction.

[Figure: P = fair [0.5, 0.5] vs Q = biased [0.9, 0.1]. Bars for P (true) and Q (model) per outcome; shaded region = per-outcome surprise contribution to D(P||Q). Heads: P=0.5, Q=0.9, −0.424 bits (negative contribution). Tails: P=0.5, Q=0.1, +1.161 bits (positive contribution).]

Heads: Q over-estimates (Q=0.9 vs P=0.5), contributing −0.424 bits. Tails: Q severely under-estimates (Q=0.1 vs P=0.5), contributing +1.161 bits. Net = 0.737 bits.
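The per-outcome numbers in the figure can be reproduced term by term. This sketch uses an illustrative helper, `kl_terms_bits`, that returns each outcome's contribution separately. That makes it easy to see that a single term can be negative even though the total never is:

```python
import math

def kl_terms_bits(p, q):
    """Per-outcome contributions P(x) * log2(P(x) / Q(x)) to D(P||Q)."""
    return [px * math.log2(px / qx) if px > 0 else 0.0 for px, qx in zip(p, q)]

P = [0.5, 0.5]  # fair coin: [heads, tails]
Q = [0.9, 0.1]  # biased model
terms = kl_terms_bits(P, Q)
for name, t in zip(["Heads", "Tails"], terms):
    print(f"{name}: {t:+.3f} bits")
print(f"Net: {sum(terms):.3f} bits")
```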

The cross-entropy identity

The entropy of P, written H(P), is the average number of bits per observation needed by the best possible code, one optimized for P itself:

H(P) = - sum over x of  P(x) * log2( P(x) )

The cross-entropy of P and Q, written H(P, Q), is the average number of bits per observation needed when data drawn from P is encoded with a code optimized for Q:

H(P, Q) = - sum over x of  P(x) * log2( Q(x) )

Subtracting H(P) from H(P, Q) leaves exactly D(P||Q), because term by term -log2 Q(x) + log2 P(x) = log2( P(x) / Q(x) ). Rearranged:

H(P, Q)  =  H(P)  +  D(P||Q)

This identity is why minimizing cross-entropy loss is the same as minimizing the KL divergence to the true label distribution. H(P) depends only on the data, not on the model, so any change to Q that lowers H(P, Q) lowers D(P||Q) by exactly the same amount. When the training labels form a one-hot distribution (one class gets probability 1, all others get 0), H(P) is zero (using the convention 0 * log2 0 = 0), so H(P, Q) = D(P||Q) exactly. Every gradient step that shrinks cross-entropy also shrinks the KL gap between your model and the truth.
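Both claims can be checked numerically. The sketch below is plain Python with illustrative helper names. It first verifies H(P, Q) = H(P) + D(P||Q) on the coin example. It then answers the opening question by treating the correct token as index 0 of a 50 000-token vocabulary with a one-hot truth:

```python
import math

def entropy_bits(p):
    """H(P) = -sum P(x) log2 P(x), skipping zero-probability outcomes (0 * log 0 = 0)."""
    return -sum(px * math.log2(px) for px in p if px > 0)

def cross_entropy_bits(p, q):
    """H(P, Q) = -sum P(x) log2 Q(x)."""
    return -sum(px * math.log2(qx) for px, qx in zip(p, q) if px > 0)

def kl_bits(p, q):
    """D(P||Q) = sum P(x) log2(P(x) / Q(x))."""
    return sum(px * math.log2(px / qx) for px, qx in zip(p, q) if px > 0)

# 1) The identity on the coin example: H(P, Q) = H(P) + D(P||Q)
P, Q = [0.5, 0.5], [0.9, 0.1]
assert math.isclose(cross_entropy_bits(P, Q), entropy_bits(P) + kl_bits(P, Q))

# 2) One-hot labels: H(P) = 0, so cross-entropy equals KL.
# Opening example: 50 000-token vocabulary, correct token at index 0.
vocab = 50_000
truth = [1.0] + [0.0] * (vocab - 1)
model_a = [0.9] + [0.1 / (vocab - 1)] * (vocab - 1)
model_b = [0.5] + [0.5 / (vocab - 1)] * (vocab - 1)

for name, q in [("A", model_a), ("B", model_b)]:
    ce, kl = cross_entropy_bits(truth, q), kl_bits(truth, q)
    print(f"model {name}: cross-entropy = {ce:.3f} bits, KL = {kl:.3f} bits")
```

Model A comes out about 0.152 bits from the truth and model B exactly 1 bit, so model B is further away. Only the probability on the correct token matters, because every other term is multiplied by P(x) = 0.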

Where KL shows up in practice

Variational Autoencoders (VAEs). The training loss has two terms: a reconstruction term and a KL term D(q(z|x) || p(z)), where q(z|x) is the encoder’s approximate posterior and p(z) is a standard Gaussian prior. The KL term penalizes the encoder for drifting too far from the prior, keeping the latent space smooth enough to sample from.
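In practice, the VAE's KL term is usually computed in closed form rather than by sampling. This is a minimal sketch with two assumptions. First, the encoder outputs a mean and a log-variance per latent dimension for a diagonal Gaussian q(z|x), which is a common parameterization but not the only one. Second, the prior is a standard normal. The result is in nats; divide by ln 2 for bits. The function name and the example encoder outputs are hypothetical:

```python
import math

def gaussian_kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) in nats.

    Closed form for a diagonal Gaussian against a standard normal prior:
    0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2), summed over latent dimensions.
    """
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))

# Hypothetical encoder outputs for one input x, 3-dimensional latent space
print(gaussian_kl_to_standard_normal([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]))  # equals the prior: 0.0
print(gaussian_kl_to_standard_normal([1.0, -0.5, 0.0], [0.0, -1.0, 0.5]))  # drifted: positive
```

The penalty is zero only when every dimension has mean 0 and variance 1, which is exactly the "don't drift from the prior" pressure described above.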

Knowledge distillation. A small student model is trained to match the soft probability outputs of a large teacher model. Minimizing D(teacher || student) in the forward direction (P = teacher, Q = student) encourages the student to cover all the modes the teacher puts weight on — including the small probabilities on “almost-right” answers that contain generalization signal.
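This is a minimal sketch of the forward-KL distillation loss for a single example. It assumes the teacher and student produce logits that are softened with a temperature T, a common choice in distillation. The logits and the value of T are hypothetical. The result is in nats because training code conventionally uses the natural log:

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; a higher temperature gives softer probabilities."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_nats(p, q):
    """D(P||Q) in nats; assumes q[i] > 0 wherever p[i] > 0 (always true after softmax)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical logits for a 4-class problem
teacher_logits = [5.0, 3.5, 0.5, -1.0]
student_logits = [4.0, 1.0, 1.0, 0.0]

T = 4.0  # hypothetical distillation temperature
p_teacher = softmax(teacher_logits, T)
q_student = softmax(student_logits, T)
loss = kl_nats(p_teacher, q_student)  # forward KL: D(teacher || student)
print([round(p, 3) for p in p_teacher])
print(loss)
```

Raising T flattens the teacher's distribution. That exposes more of the small probabilities on "almost-right" classes, which the student is then pushed to match.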

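RLHF and PPO. When fine-tuning a language model on human feedback, a KL penalty, β times D(policy || reference), is subtracted from the reward. Here the reference is the model before RL fine-tuning, and β sets how strongly the policy is pulled back toward it. The penalty stops the policy from drifting too far from the base model. Without it, the model can collapse to a narrow set of reward-hacking outputs.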
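Here is one common way to apply the penalty, sketched under stated assumptions. For a single sampled response, sum the per-token log-probability ratios between the policy and the frozen reference model. That gives a one-sample estimate of D(policy || reference). Then subtract β times that estimate from the reward-model score. The function name, its arguments, and β = 0.1 are hypothetical, and real implementations differ in details such as applying the penalty per token:

```python
def kl_penalized_reward(reward, policy_logprobs, ref_logprobs, beta=0.1):
    """Reward minus beta times a sampled estimate of D(policy || reference).

    policy_logprobs / ref_logprobs: log-probabilities that the policy and the
    frozen reference model assign to each token of the sampled response.
    The summed log-ratio is a single-sample estimate of the sequence-level KL;
    its expectation over responses drawn from the policy equals that KL,
    although any single estimate can be negative.
    """
    log_ratio = sum(lp - lr for lp, lr in zip(policy_logprobs, ref_logprobs))
    return reward - beta * log_ratio

# Hypothetical 3-token response: the policy is more confident than the reference
shaped = kl_penalized_reward(1.0, [-0.1, -0.2, -0.05], [-0.5, -0.9, -0.3])
print(shaped)  # reward-model score of 1.0 reduced by the KL penalty
```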

The asymmetry matters in practice

In the forward KL direction (D(P||Q), P true, Q model): wherever P assigns probability, Q is forced to assign some too — otherwise the log ratio blows up. This makes Q try to cover all modes of P. It is sometimes called mean-seeking or mass-covering.

In the reverse KL direction (D(Q||P), Q model, P true): wherever Q assigns probability, P must assign some — the penalty is infinite if Q puts mass where P has none. This makes Q choose one mode and commit. It is sometimes called mode-seeking.
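The contrast is visible without any training if you score two hand-picked candidate models against a bimodal target. The numbers below are made up for illustration. `q_cover` spreads its mass over both modes, and `q_mode` commits to the left one. Forward KL should favor the first, and reverse KL the second:

```python
import math

def kl_bits(p, q):
    """D(P||Q) in bits; infinite if Q(x) = 0 somewhere P(x) > 0."""
    total = 0.0
    for px, qx in zip(p, q):
        if px > 0:
            if qx == 0:
                return math.inf
            total += px * math.log2(px / qx)
    return total

# Bimodal "true" distribution over 5 outcomes: mass piled on both ends
P = [0.47, 0.02, 0.02, 0.02, 0.47]

# Two hand-picked candidate models
q_cover = [0.2, 0.2, 0.2, 0.2, 0.2]          # spreads mass over both modes
q_mode = [0.90, 0.025, 0.025, 0.025, 0.025]  # commits to the left mode

for name, q in [("cover", q_cover), ("mode", q_mode)]:
    print(f"{name:5s}  forward D(P||Q) = {kl_bits(P, q):.3f}  reverse D(Q||P) = {kl_bits(q, P):.3f}")
```

Lower is better in each column. With these numbers, forward KL rates the covering model better (about 0.96 vs 1.53 bits), while reverse KL rates the mode-committing model better (about 0.76 vs 1.50 bits).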

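VAEs are an instance of the reverse direction. Their training objective, the ELBO, comes from minimizing D(q(z|x) || p(z|x)) with the encoder's approximation q in the first slot, and the KL penalty D(q(z|x) || p(z)) keeps that same ordering. Variational inference in general uses reverse KL because D(Q||P) is an expectation under Q, which you choose and can sample from. P's intractable normalizing constant only adds a constant that does not change the optimum. The choice is not cosmetic: it changes what the model learns to approximate.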

Quiz

Q1. P = [0.5, 0.5], Q = [0.9, 0.1]. What is D(P||Q) in bits, rounded to two decimal places?
Q2. A classifier's cross-entropy loss on a one-hot label distribution decreases from 1.2 to 0.4. What happened to D(true labels || model output)?
Q3. A new generative model uses reverse KL, D(Q||P), where Q is the model and P is the data distribution. Compared to a model trained with forward KL, you expect the reverse-KL model to produce outputs that are:

Next

Explore Jensen-Shannon divergence — the symmetrized, bounded cousin of KL, and the loss function behind the original GAN training objective.


Practice this in an interview

What do skewness and kurtosis measure, and what are their practical implications?

Skewness measures the asymmetry of a distribution's tails: positive skew means a longer right tail, and negative skew means a longer left tail. Kurtosis measures how heavy the tails are relative to a normal distribution. Positive excess kurtosis (kurtosis minus 3) indicates heavier tails than a Gaussian, so extreme values and outliers occur more often, which matters for risk estimates and outlier handling.
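Both quantities are standardized moments. This is a minimal sketch in plain Python using the population (biased) moment estimators. Libraries often offer bias-corrected variants that give slightly different values on small samples, and the function name here is illustrative:

```python
def skewness_and_excess_kurtosis(xs):
    """Population (moment-based) skewness and excess kurtosis of a sample."""
    n = len(xs)
    mean = sum(xs) / n
    m2 = sum((x - mean) ** 2 for x in xs) / n  # variance
    m3 = sum((x - mean) ** 3 for x in xs) / n  # third central moment
    m4 = sum((x - mean) ** 4 for x in xs) / n  # fourth central moment
    skew = m3 / m2 ** 1.5
    excess_kurt = m4 / m2 ** 2 - 3.0  # subtract 3 so a Gaussian scores 0
    return skew, excess_kurt

print(skewness_and_excess_kurtosis([-2, -1, 0, 1, 2]))        # symmetric: skew 0
print(skewness_and_excess_kurtosis([1, 1, 1, 2, 2, 3, 10]))   # long right tail: positive skew
```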

Why use cross-entropy loss instead of MSE for classification?

MSE treats class probabilities as continuous values and produces tiny, saturating gradients when a sigmoid output is near 0 or 1, stalling learning. Cross-entropy is the proper log-likelihood loss for categorical distributions; it keeps gradients large and informative even when the network is very wrong, and its minimum aligns with the true class probabilities.
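The saturation argument can be checked by comparing gradients with respect to the pre-sigmoid logit z for a single binary example with true label 1. This is a minimal sketch with illustrative function names:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def grad_mse(z, y):
    """d/dz of (sigmoid(z) - y)^2; the sigmoid derivative p(1 - p) shrinks it."""
    p = sigmoid(z)
    return 2.0 * (p - y) * p * (1.0 - p)

def grad_bce(z, y):
    """d/dz of binary cross-entropy -[y log p + (1 - y) log(1 - p)] with p = sigmoid(z)."""
    return sigmoid(z) - y

# True label is 1, but the network is increasingly confident it is 0
for z in [-2.0, -6.0, -10.0]:
    print(f"z = {z:6.1f}  MSE grad = {grad_mse(z, 1):.6f}  CE grad = {grad_bce(z, 1):.6f}")
```

As z becomes more negative and the network grows more confidently wrong, the MSE gradient shrinks toward zero because of the p(1 − p) factor, while the cross-entropy gradient stays close to −1.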

Why does training loss keep falling while validation loss rises?

This divergence is the signature of overfitting: the model has enough capacity to memorise training-set specifics — noise, label errors, dataset-specific patterns — that do not generalise. Training loss measures fit to what has already been seen; validation loss measures generalisation to held-out data. As the model memorises rather than learns structure, it scores better on training data and worse on everything else.

What is log loss and why does it penalise confident wrong predictions more than uncertain ones?

Log loss (cross-entropy loss) measures how well a model's predicted probabilities match the true labels: it is the negative log-likelihood of the correct class. It penalises confident wrong predictions severely because log(p), where p is the probability given to the correct class, approaches negative infinity as p approaches zero. In the binary case, predicting 0.99 for the wrong class costs -ln(0.01) ≈ 4.61, roughly 5x the -ln(0.4) ≈ 0.92 for predicting 0.6 for the wrong class, and the penalty grows without bound as the wrong-class probability approaches 1. A perfect model achieves 0; a binary classifier that always predicts 0.5 achieves ln(2) ≈ 0.693.
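The penalties for confident and mildly wrong predictions can be computed directly. This is a minimal sketch of natural-log binary log loss with an illustrative function name. The true label is 0, so the model's predicted probability of class 1 is the probability it gives the wrong class:

```python
import math

def binary_log_loss(y, p):
    """Natural-log binary cross-entropy for true label y (0 or 1) and predicted P(y = 1) = p."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# True label is 0; p is the probability the model puts on the wrong class (1)
confident_wrong = binary_log_loss(0, 0.99)  # -ln(0.01), about 4.61
mildly_wrong = binary_log_loss(0, 0.6)      # -ln(0.4), about 0.92
always_half = binary_log_loss(0, 0.5)       # ln 2, about 0.693
print(confident_wrong, mildly_wrong, confident_wrong / mildly_wrong)
```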
