datarekha
Deep Learning · Hard · Asked at Google, Meta, OpenAI and DeepMind

Your model's training loss isn't dropping at all. How do you systematically debug it?

The short answer

A flat or erratic loss almost always indicates a bug — in data loading, label encoding, loss function, or gradient flow — not an insufficiently tuned learning rate. Systematic debugging means isolating each component and verifying it works on a tiny, controlled example before scaling up.

How to think about it

The first instinct is to adjust the learning rate. Resist this. A flat loss is almost always a bug in the pipeline, not a hyperparameter problem.

Systematic debugging checklist

Step 1 — Overfit a single batch

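# Take one batch (x_single, y_single) and train on it alone for 200 steps
model.train()
for step in range(200):
    optimizer.zero_grad()             # clear stale gradients before each backward pass
    loss = criterion(model(x_single), y_single)
    loss.backward()
    optimizer.step()
    if step % 50 == 0 or step == 199:
        print(step, loss.item())      # should trend towards ~0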

If loss does not drop to near-zero on one batch, the model or loss function is broken regardless of data. Fix this first.
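The loop above assumes model, criterion, optimizer and the batch already exist. Below is a minimal self-contained sketch of the same test; the two-layer MLP, the random batch of 16 examples with 3 classes, Adam at lr 1e-2 and the 300-step budget are illustrative assumptions, not values from the original. A correctly wired model should memorise this batch and push the loss close to zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: one fixed batch of 16 examples, 10 features, 3 classes
x_single = torch.randn(16, 10)
y_single = torch.randint(0, 3, (16,))

model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
model.train()
for step in range(300):
    optimizer.zero_grad()
    loss = criterion(model(x_single), y_single)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# A healthy pipeline shows a large drop from the first to the last value
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.4f}")
```

If the same loop with your real model and one real batch plateaus well above zero, the bug is in the model, the loss or the training step, not in the amount of data.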

Step 2 — Check labels and data loading

Visualise a batch: print inputs, labels, and class distributions. Common bugs:

  • Label indices off by one (class 0 mapped to index 1, etc.)
  • Inputs and labels shuffled independently (e.g., two separately shuffled lists inside a custom Dataset), so each input ends up paired with the wrong label
  • Targets wrong dtype — e.g., float instead of long for CrossEntropyLoss
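The checks in this list can be automated as a quick sanity pass over one batch. This is a sketch; the helper name check_batch and the "only one class in the batch" heuristic are assumptions for illustration, not part of any library.

```python
from collections import Counter

import torch


def check_batch(x: torch.Tensor, y: torch.Tensor, num_classes: int) -> list[str]:
    """Hypothetical helper: return human-readable problems found in one (inputs, labels) batch."""
    problems = []
    if x.shape[0] != y.shape[0]:
        problems.append(f"batch size mismatch: {x.shape[0]} inputs vs {y.shape[0]} labels")
    if y.dtype != torch.long:
        problems.append(f"labels are {y.dtype}; CrossEntropyLoss class indices must be torch.long")
    if y.numel() > 0 and (y.min() < 0 or y.max() >= num_classes):
        # Typical symptom of off-by-one labels: values run 1..C instead of 0..C-1
        problems.append(f"labels span [{int(y.min())}, {int(y.max())}], expected [0, {num_classes - 1}]")
    if not torch.isfinite(x).all():
        problems.append("inputs contain NaN or Inf")
    counts = Counter(y.tolist())  # class distribution of this batch
    if len(counts) == 1:
        # Heuristic only: can fire legitimately for very small batches
        problems.append(f"only one class present in batch: {dict(counts)}")
    return problems


x = torch.randn(4, 5)
print(check_batch(x, torch.tensor([0, 1, 2, 1]), num_classes=3))  # → []
# Float labels running 1..3 for a 3-class problem: flags both dtype and range
print(check_batch(x, torch.tensor([1.0, 2.0, 3.0, 2.0]), num_classes=3))
```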

Step 3 — Verify loss function is correct

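import torch.nn as nn

# CrossEntropyLoss expects raw logits and class indices (dtype long), not one-hot vectors;
# it applies log-softmax internally
criterion = nn.CrossEntropyLoss()
# Wrong: criterion(torch.softmax(logits, dim=1), y)  -> softmax applied twice
# Right: criterion(logits, y)                        -> raw logits directly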
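Why the double softmax makes the loss look stuck: the second softmax only ever sees values in [0, 1], so even a perfectly confident, correct prediction cannot drive the loss to zero. With C classes the floor is log(1 + (C - 1)/e), about 0.5514 for C = 3. The logits below are made-up values chosen to represent a confident correct prediction.

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Illustrative, very confident and correct prediction for class 0 out of 3 classes
logits = torch.tensor([[20.0, -20.0, -20.0]])
target = torch.tensor([0])

right = criterion(logits, target)                        # raw logits, as intended
wrong = criterion(torch.softmax(logits, dim=1), target)  # softmax applied twice

# Theoretical floor of the double-softmax loss for C = 3 classes
floor = math.log(1 + 2 / math.e)

print(f"raw logits:     {right.item():.6f}")
print(f"double softmax: {wrong.item():.4f} (floor {floor:.4f})")
```

The correct call reports a loss of essentially zero, while the buggy call sits at the floor no matter how good the model becomes.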

Step 4 — Check gradient flow

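# Run right after loss.backward(), before optimizer.zero_grad()
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen on purpose; no gradient expected
    if param.grad is not None:
        print(name, param.grad.abs().mean().item())
    else:
        print(name, "NO GRADIENT")  # disconnected subgraph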

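After loss.backward(), a parameter whose gradient is still None is not connected to the loss at all (typical causes: a layer defined but never used in forward, or an accidental .detach()), so it receives no learning signal. A gradient that exists but is close to zero points instead to saturation or dead units (Step 6).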
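A small reproduction of the "disconnected subgraph" case, as a sketch: BuggyNet and gradient_report are hypothetical names, and the bug is an intentional .detach() inside forward. The report separates parameters with no gradient from those whose gradient exists but is all zero.

```python
import torch
import torch.nn as nn


class BuggyNet(nn.Module):
    """Hypothetical model with a classic bug: .detach() cuts the encoder off from the loss."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        h = torch.relu(self.encoder(x)).detach()  # BUG: no gradient flows back past here
        return self.head(h)


def gradient_report(model: nn.Module) -> dict[str, str]:
    """Classify each trainable parameter's gradient; call after loss.backward()."""
    report = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            report[name] = "NO GRADIENT"
        elif param.grad.abs().max().item() == 0.0:
            report[name] = "all-zero gradient"
        else:
            report[name] = f"mean |grad| = {param.grad.abs().mean().item():.2e}"
    return report


torch.manual_seed(0)
model = BuggyNet()
loss = nn.CrossEntropyLoss()(model(torch.randn(4, 8)), torch.tensor([0, 1, 0, 1]))
loss.backward()
for name, status in gradient_report(model).items():
    print(name, status)  # encoder.* show NO GRADIENT, head.* show a magnitude
```

Removing the .detach() call makes the encoder parameters report a gradient magnitude as well.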

Step 5 — Check learning rate

Too high: loss oscillates, diverges, or turns into NaN. Too low: loss moves, but imperceptibly slowly. Run the single-batch test from Step 1 at 1e-6, 1e-4 and 1e-2 and plot each loss curve to find the right order of magnitude.
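A minimal sketch of that sweep on a single toy batch. The model, the data, plain SGD and the 100-step budget are assumptions for illustration; each learning rate starts from the same seed so only the learning rate differs. It records the last loss rather than plotting the full curve.

```python
import math

import torch
import torch.nn as nn


def final_loss_at_lr(lr: float, steps: int = 100, seed: int = 0) -> float:
    """Train a fresh toy model on one fixed batch and return the loss recorded on the last step."""
    torch.manual_seed(seed)  # same data and same initial weights for every lr
    x = torch.randn(16, 10)
    y = torch.randint(0, 3, (16,))
    model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 3))
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


results = {lr: final_loss_at_lr(lr) for lr in (1e-6, 1e-4, 1e-2)}
for lr, loss in results.items():
    flag = "" if math.isfinite(loss) else "  <- diverged"
    print(f"lr={lr:g}: final loss {loss:.4f}{flag}")
```

At 1e-6 the loss barely moves from its starting value; the order of magnitude where it falls steadily without blowing up is the one to carry into full training.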

Step 6 — Inspect activation statistics

Saturated sigmoid or tanh units and dead ReLUs (outputs stuck at zero) pass back near-zero or zero gradients. Print the mean and standard deviation of each layer’s activations, plus the fraction of exact zeros after each ReLU, to detect this.
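One way to collect these statistics without editing the model is a forward hook on every activation module (register_forward_hook is a standard nn.Module method). The sketch below is illustrative: the helper name activation_stats and the 0.01/0.99 saturation thresholds are assumptions, and the demo deliberately kills every ReLU with a large negative bias.

```python
import torch
import torch.nn as nn


def activation_stats(model: nn.Module, x: torch.Tensor) -> dict[str, dict[str, float]]:
    """Hypothetical helper: one forward pass, summarising the output of each activation module."""
    stats = {}
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            out = output.detach()
            s = {
                "mean": out.mean().item(),
                "std": out.std().item(),
                "frac_zero": (out == 0).float().mean().item(),  # dead-ReLU signal
            }
            if isinstance(module, nn.Tanh):
                s["frac_saturated"] = (out.abs() > 0.99).float().mean().item()
            elif isinstance(module, nn.Sigmoid):
                s["frac_saturated"] = ((out < 0.01) | (out > 0.99)).float().mean().item()
            stats[name] = s
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.ReLU, nn.Tanh, nn.Sigmoid)):
            hooks.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(x)
    for h in hooks:
        h.remove()  # leave the model without extra hooks
    return stats


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
with torch.no_grad():
    model[0].bias.fill_(-100.0)  # deliberately kill every ReLU unit for the demo
for name, s in activation_stats(model, torch.randn(64, 10)).items():
    print(name, s)  # layer "1" (the ReLU) reports frac_zero = 1.0
```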
