datarekha
Deep Learning · Medium · Asked at Google, Meta, Amazon

Why use cross-entropy loss instead of MSE for classification?

The short answer

MSE treats class probabilities as ordinary regression targets, and its gradient vanishes when a sigmoid output saturates near 0 or 1, which stalls learning exactly when the prediction is confidently wrong. Cross-entropy is the proper log-likelihood loss for categorical distributions; its gradient stays large and informative when the network is very wrong, and its minimum aligns with the true class probabilities.

How to think about it

Cross-entropy is the principled choice because it matches the statistical model: classification networks output a probability distribution over classes, and cross-entropy is exactly the negative log-likelihood of the true labels under that distribution.
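To make the link to log-likelihood concrete, here is a small pure-Python sketch. The probabilities and the one-hot target are made-up illustrative values, not output from any real model. With a one-hot label, the cross-entropy sum collapses to the negative log of the probability assigned to the true class.

```python
import math

# Hypothetical predicted distribution over three classes (illustrative values)
probs = [0.7, 0.2, 0.1]
# One-hot encoding of the true label: class 0
one_hot = [1.0, 0.0, 0.0]

# Cross-entropy between the one-hot target and the prediction
cross_entropy = -sum(t * math.log(q) for t, q in zip(one_hot, probs))

# Negative log-likelihood of the true class under the prediction
nll = -math.log(probs[0])

print(cross_entropy, nll)  # the two values coincide
```

Every term for a wrong class is multiplied by zero, so only the true class contributes.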

Why MSE fails for classification

With a sigmoid output p = σ(z) and the loss L = ½(p − y)² (without the ½, the gradient picks up a constant factor of 2), the MSE gradient with respect to z is:

∂L/∂z = (p − y) · p · (1 − p)

When the network is confidently wrong (for example, p near 0 while y = 1), the sigmoid is saturated and p(1 − p) ≈ 0, so the gradient nearly vanishes. The network barely updates even though it is making a large error.
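The effect is easy to see numerically. The sketch below evaluates the gradient formula above for a positive example (y = 1) at increasingly wrong logits, assuming the ½(p − y)² convention that produces that formula. The helper names `sigmoid` and `mse_grad` are illustrative, not from any library.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))

def mse_grad(z: float, y: float) -> float:
    """dL/dz for L = 0.5 * (sigmoid(z) - y) ** 2."""
    p = sigmoid(z)
    return (p - y) * p * (1.0 - p)

# True label is 1; more negative z means a more confidently wrong prediction
for z in (0.0, -2.0, -5.0, -10.0):
    print(f"z={z:6.1f}  p={sigmoid(z):.5f}  grad={mse_grad(z, 1.0):.6f}")
```

The gradient magnitude shrinks as the prediction gets worse, which is the opposite of what a learning signal should do.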

Cross-entropy avoids saturation

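Binary cross-entropy loss is L = −[y log p + (1−y) log(1−p)]. With p = σ(z), the chain-rule factor σ′(z) = p(1 − p) cancels against the 1/p and 1/(1 − p) terms that come from differentiating the logs, so the gradient w.r.t. z simplifies to just p − y, with no saturation term. The gradient equals the prediction error: the bigger the error, the bigger the gradient.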
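As a sanity check on the p − y result, the following sketch compares it with a central finite-difference estimate of the binary cross-entropy loss. The function names and the step size `eps` are illustrative choices, and the test points are kept moderate so that log(p) and log(1 − p) stay finite.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def bce(z: float, y: float) -> float:
    """Binary cross-entropy as a function of the logit z."""
    p = sigmoid(z)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))

def bce_grad(z: float, y: float) -> float:
    """Analytic gradient dL/dz = p - y."""
    return sigmoid(z) - y

def numeric_grad(z: float, y: float, eps: float = 1e-6) -> float:
    """Central finite-difference estimate of dL/dz."""
    return (bce(z + eps, y) - bce(z - eps, y)) / (2.0 * eps)

for z in (-5.0, 0.0, 3.0):
    for y in (0.0, 1.0):
        print(z, y, bce_grad(z, y), numeric_grad(z, y))
```

For a confidently wrong prediction such as z = −10 with y = 1, `bce_grad` returns a value close to −1, so the update stays strong.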

import torch
import torch.nn as nn

# One sample, three classes: logits have shape (batch, num_classes)
logits = torch.tensor([[2.5, -1.0, 0.3]])   # raw scores, NOT softmax outputs
targets = torch.tensor([0])                 # integer class index, shape (batch,)

# CrossEntropyLoss = LogSoftmax + NLLLoss; pass logits directly
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)             # scalar: mean loss over the batch

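Pass raw logits to nn.CrossEntropyLoss: it applies log-softmax internally for numerical stability. If you apply softmax yourself first, softmax is effectively applied twice, which silently produces a wrong loss value and weaker gradients, not just a loss of precision.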
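The stability point can be illustrated without PyTorch. The sketch below is not PyTorch's implementation; it is a minimal pure-Python version of the standard log-sum-exp trick, with hypothetical helper names and deliberately extreme logits chosen to trigger overflow in the naive version.

```python
import math

def naive_log_softmax(logits):
    """Softmax first, then log: exp() overflows for large logits."""
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def stable_log_softmax(logits):
    """Log-sum-exp trick: subtract the max before exponentiating."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]

logits = [1000.0, 0.0, -1000.0]  # extreme values, chosen for illustration

try:
    naive_log_softmax(logits)
    naive_ok = True
except OverflowError:
    naive_ok = False  # math.exp(1000.0) is out of range for a float

stable = stable_log_softmax(logits)
print(naive_ok, stable)
```

Working in log space from the start keeps every intermediate value representable, which is why fused log-softmax losses take logits rather than probabilities.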

When MSE is appropriate

MSE suits continuous targets (regression), where outputs have no probability interpretation and squared deviation is a natural cost.
