Your model's training loss isn't dropping at all. How do you systematically debug it?
A flat or erratic loss almost always indicates a bug — in data loading, label encoding, loss function, or gradient flow — not an insufficiently tuned learning rate. Systematic debugging means isolating each component and verifying it works on a tiny, controlled example before scaling up.
How to think about it
The first instinct is to adjust the learning rate. Resist this. A flat loss is almost always a bug in the pipeline, not a hyperparameter problem.
Systematic debugging checklist
Step 1 — Overfit a single batch
# Take one batch, run 200 steps on that batch only
for _ in range(200):
    loss = criterion(model(x_single), y_single)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
If the loss does not drop to near zero on one batch, the model, loss function, or optimizer wiring is broken regardless of the rest of the data. Fix this first.
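As a self-contained illustration of the single-batch test, the sketch below uses a hypothetical two-layer MLP, random synthetic data, and Adam at 1e-2. None of these choices come from the original; they are assumptions picked only so the script runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the run is repeatable

# Hypothetical stand-ins for one real batch: 16 samples, 10 features, 3 classes.
x_single = torch.randn(16, 10)
y_single = torch.randint(0, 3, (16,))  # class indices, dtype torch.long

# Assumed toy model and optimizer; swap in your own.
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for _ in range(200):
    optimizer.zero_grad()
    loss = criterion(model(x_single), y_single)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

initial_loss, final_loss = losses[0], losses[-1]
print(f"first step: {initial_loss:.4f}, last step: {final_loss:.4f}")
```

A healthy pipeline should memorise these 16 samples, so the last value should be far below the first. If the two numbers are close, the problem is in the model, loss, or optimizer rather than in the dataset.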
Step 2 — Check labels and data loading
Visualise a batch: print inputs, labels, and class distributions. Common bugs:
- Label indices off by one (class 0 mapped to index 1, etc.)
- DataLoader accidentally shuffling labels but not inputs
- Targets with the wrong dtype, e.g., float instead of long for CrossEntropyLoss
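Some of these label checks can be automated. The helper below is a minimal sketch: `check_batch` is a hypothetical name, not a PyTorch API, and it assumes 1-D class-index labels. It catches dtype and range errors but cannot detect labels that are shuffled relative to their inputs, which still needs visual inspection.

```python
import torch

def check_batch(inputs, labels, num_classes):
    """Return a list of human-readable problems found in one batch."""
    problems = []
    if inputs.shape[0] != labels.shape[0]:
        problems.append(
            f"batch size mismatch: {inputs.shape[0]} inputs vs {labels.shape[0]} labels"
        )
    if labels.dtype != torch.long:
        problems.append(f"labels have dtype {labels.dtype}, expected torch.long")
        return problems  # the checks below need integer labels
    if labels.numel() > 0:
        lo, hi = labels.min().item(), labels.max().item()
        if lo < 0 or hi >= num_classes:
            problems.append(
                f"label range [{lo}, {hi}] is outside [0, {num_classes - 1}]"
            )
            return problems  # bincount below needs in-range labels
    counts = torch.bincount(labels.flatten(), minlength=num_classes).tolist()
    missing = [c for c, n in enumerate(counts) if n == 0]
    if missing:
        # Can be normal for small batches, but worth a look.
        problems.append(f"classes absent from this batch: {missing}")
    return problems

x = torch.randn(6, 4)
good = torch.tensor([0, 1, 2, 0, 1, 2])
off_by_one = good + 1    # labels 1..3 while only 0..2 are valid
as_float = good.float()  # wrong dtype for class-index targets

for name, labels in [("good", good), ("off_by_one", off_by_one), ("as_float", as_float)]:
    print(name, check_batch(x, labels, num_classes=3))
```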
Step 3 — Verify loss function is correct
# CrossEntropyLoss expects raw logits + class indices (NOT one-hot)
criterion = nn.CrossEntropyLoss()
# Wrong: model output through softmax first, then CE → double-softmax
# Right: raw logits directly
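To see why the double softmax matters, the sketch below compares the two variants on a single hand-picked example (the logits and the 3-class setup are illustrative assumptions). Because softmax outputs lie in [0, 1], feeding them to cross-entropy as if they were logits puts a floor under the loss, which can look like a plateau during training.

```python
import math
import torch
import torch.nn.functional as F

# A confidently correct prediction for class 0 out of 3 classes.
logits = torch.tensor([[10.0, 0.0, 0.0]])
target = torch.tensor([0])

# Right: raw logits go straight into cross-entropy.
right = F.cross_entropy(logits, target).item()

# Wrong: softmax applied first, then cross-entropy applies log-softmax again.
wrong = F.cross_entropy(F.softmax(logits, dim=1), target).item()

# With probabilities fed in as "logits", every input is confined to [0, 1],
# so for C classes the loss cannot fall below log(1 + (C - 1) / e).
num_classes = 3
floor = math.log(1 + (num_classes - 1) / math.e)

print(f"raw logits: {right:.5f}")
print(f"double softmax: {wrong:.5f} (floor for {num_classes} classes: {floor:.5f})")
```

Even with an almost perfect prediction, the double-softmax loss stays above roughly 0.55 for three classes, while the correct version is close to zero.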
Step 4 — Check gradient flow
for name, param in model.named_parameters():
    if param.grad is not None:
        print(name, param.grad.abs().mean().item())
    else:
        print(name, "NO GRADIENT")  # disconnected subgraph
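A parameter whose gradient is None is not receiving any learning signal. Run this check after loss.backward() and before optimizer.zero_grad(), since recent PyTorch versions reset gradients to None by default at that point.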
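A disconnected parameter is easy to reproduce on purpose. In the sketch below, `TwoBranch` and `grad_report` are hypothetical names introduced for illustration; the model deliberately defines a layer that `forward()` never calls, so its parameters end up with no gradient.

```python
import torch
import torch.nn as nn

class TwoBranch(nn.Module):
    """Hypothetical model with one layer that never feeds the output."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # defined but never called in forward()

    def forward(self, x):
        return self.used(x)

def grad_report(model):
    """Map each parameter name to its mean absolute gradient, or None."""
    return {
        name: (None if p.grad is None else p.grad.abs().mean().item())
        for name, p in model.named_parameters()
    }

torch.manual_seed(0)
model = TwoBranch()
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()  # gradients only exist after backward()

report = grad_report(model)
for name, value in report.items():
    print(name, "NO GRADIENT" if value is None else f"{value:.6f}")
```

The same pattern shows up in real code when a tensor is detached, converted to NumPy and back, or when a submodule is created but left out of the forward pass.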
Step 5 — Check learning rate
Too high: loss oscillates or diverges. Too low: loss moves, but imperceptibly slowly. Plot the loss curve at learning rates of 1e-6, 1e-4, and 1e-2 on the single-batch test from Step 1 to find the right order of magnitude.
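The sweep can be scripted in a few lines. This is a sketch under assumptions: the toy MLP, the synthetic batch, and the Adam optimizer are illustrative choices, and `loss_curve` is a hypothetical helper. It records the curves and prints the endpoints; plotting is left to whichever tool you prefer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical single batch: 16 samples, 10 features, 3 classes.
x_single = torch.randn(16, 10)
y_single = torch.randint(0, 3, (16,))

def loss_curve(lr, steps=200):
    """Train a fresh copy of the toy model on one batch and return per-step losses."""
    torch.manual_seed(1)  # identical initial weights for every learning rate
    model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 3))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    curve = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(x_single), y_single)
        loss.backward()
        optimizer.step()
        curve.append(loss.item())
    return curve

curves = {lr: loss_curve(lr) for lr in (1e-6, 1e-4, 1e-2)}
for lr, curve in curves.items():
    print(f"lr={lr:g}: start {curve[0]:.4f} -> end {curve[-1]:.4f}")
```

Resetting the seed inside `loss_curve` keeps the comparison fair: every learning rate starts from the same weights, so differences in the curves come from the step size alone.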
Step 6 — Inspect activation statistics
Saturated sigmoids or tanh, or dead ReLUs (all-zero outputs), produce near-zero gradients. Print mean and standard deviation of each layer’s activations to detect this.
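One way to collect these statistics is with forward hooks. The sketch below is illustrative: `activation_stats` is a hypothetical helper, the toy model is an assumption, and the first layer's bias is set to a large negative value on purpose to simulate a dead ReLU.

```python
import torch
import torch.nn as nn

def activation_stats(model, x):
    """Run one forward pass and record output statistics of activation layers."""
    stats = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            out = output.detach()
            stats[name] = {
                "mean": out.mean().item(),
                "std": out.std().item(),
                "frac_zero": (out == 0).float().mean().item(),
            }
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.ReLU, nn.Sigmoid, nn.Tanh)):
            handles.append(module.register_forward_hook(make_hook(name)))

    with torch.no_grad():
        model(x)
    for handle in handles:
        handle.remove()  # leave the model without lingering hooks
    return stats

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
with torch.no_grad():
    model[0].bias.fill_(-100.0)  # force every pre-activation negative

stats = activation_stats(model, torch.randn(16, 10))
for name, s in stats.items():
    print(name, s)
```

A `frac_zero` near 1.0 for a ReLU layer means almost no units are active, so almost no gradient passes through it. For sigmoid or tanh, a mean pinned near the output limits with a very small standard deviation suggests saturation.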