datarekha

The training loop

Forward, loss, backward, step, zero-grad — the five-line ritual that turns a pile of layers into a model that learns. Built and run from scratch.

8 min read Beginner Deep Learning Lesson 4 of 27

What you'll learn

  • The five steps every training loop runs, and what each one does
  • Why forgetting zero_grad() silently breaks training
  • The epoch / batch structure and the train vs eval split

Before you start

You’ve met the parts — tensors, autograd, activations, a loss function. But a pile of parts is not a model that learns. The thing that turns parts into learning is a short ritual you will type ten thousand times in your career, and it is only five steps long:

1. forward    pred = model(x)            # run the network
2. loss       L = loss_fn(pred, y)       # how wrong was it?
3. backward   L.backward()               # autograd fills every .grad
4. step       optimizer.step()           # nudge each weight downhill
5. zero       optimizer.zero_grad()      # wipe grads for the next round

Steps 1–2 are the forward pass: predict, then score. Step 3 is the backward pass: autograd finds the slope of the loss with respect to every weight. Step 4 takes one small step down that slope, and step 5 resets the gradients for the next iteration. Run that loop enough times on enough data and the weights drift to values that make the loss small. That drift is learning.
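In symbols, step 4 for plain gradient descent can be written as below. This is a sketch under assumptions: the symbols θ (all the weights), η (the learning rate), and L (the loss) are introduced here for illustration, and optimizers such as Adam use a more elaborate rule.

```latex
\theta \leftarrow \theta - \eta \, \nabla_{\theta} L(\theta)
```

For example, a single weight w = 0.5 with gradient 2.0 and η = 0.1 becomes 0.5 − 0.1 × 2.0 = 0.3.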

[Figure: a batch (x, y) flows forward through model(x) and loss_fn to a loss ("how wrong?"); backward sends gradients the other way; optimizer.step() nudges every weight downhill and optimizer.zero_grad() wipes the grads; repeat for the next batch.]
One iteration: forward to a loss, backward to gradients, step the weights, zero the grads, repeat.

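Each iteration should leave the loss a little lower as the weights move. Skip zero_grad() and that stops being true: every backward() call adds its gradients to whatever is already stored in .grad, so the gradient piles up across iterations, the updates grow, and training can blow apart without raising any error.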
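The accumulation is easy to observe directly. The following is a minimal sketch, assuming PyTorch is installed; the scalar parameter `w` and the toy loss `3 * w` are made up for illustration and are not part of the lesson's model.

```python
import torch

# A single scalar "weight" (hypothetical, for illustration only)
w = torch.tensor(1.0, requires_grad=True)

# The gradient of 3*w with respect to w is 3 on every call.
(3 * w).backward()
first = w.grad.item()       # gradient after one backward pass

(3 * w).backward()          # second backward pass, no zeroing in between
second = w.grad.item()      # the new gradient was ADDED to the stored one

w.grad.zero_()              # clear the stored gradient by hand;
                            # optimizer.zero_grad() clears it for every parameter it manages
(3 * w).backward()
third = w.grad.item()       # back to a single pass's gradient

print(first, second, third)
```

The second value is twice the first because nothing cleared `.grad` between the two backward passes. In a real loop the same thing happens on every batch, so the optimizer ends up stepping along a running sum of old gradients.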

Build it for real (in NumPy)

PyTorch’s loss.backward() is convenient, but the loop has no magic in it. Here is the entire ritual on a tiny linear-regression problem, with the gradient computed by hand so you can see exactly what step() consumes. Run it and watch the loss drop each epoch.
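A minimal sketch of such a loop follows. It assumes synthetic data drawn from y = 2x + 1 with a little noise; the seed, noise level, learning rate, and epoch count are illustrative choices rather than values taken from the lesson.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data (an assumption for this sketch): y = 2x + 1 plus small noise
x = rng.uniform(-1.0, 1.0, size=100)
y = 2.0 * x + 1.0 + 0.05 * rng.normal(size=100)

w, b = 0.0, 0.0      # parameters, starting far from the answer
lr = 0.1             # learning rate (illustrative value)
losses = []

for epoch in range(200):
    pred = w * x + b                    # 1. forward
    err = pred - y
    loss = np.mean(err ** 2)            # 2. loss (mean squared error)
    grad_w = 2.0 * np.mean(err * x)     # 3. backward: dL/dw, derived by hand
    grad_b = 2.0 * np.mean(err)         #    and dL/db
    w -= lr * grad_w                    # 4. step
    b -= lr * grad_b
    # 5. zero: nothing to wipe, because grad_w and grad_b are
    #    recomputed from scratch on every pass instead of accumulated
    losses.append(loss)
    if epoch % 40 == 0:
        print(f"epoch {epoch:3d}  loss {loss:.4f}  w {w:.3f}  b {b:.3f}")
```

This version uses the whole dataset on every pass, so one iteration equals one epoch. The two `grad_` lines are exactly what `loss.backward()` would produce for this model and loss.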

The loss falls, and w and b crawl toward 2 and 1. That is the whole of supervised learning — a loop that keeps nudging parameters in the direction that shrinks the loss.

The same loop in PyTorch

In real code, autograd computes the gradients for you and an optimizer holds the update rule. The five steps map one-to-one:

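import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy data, assumed here for illustration: y = 2x + 1
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = nn.Linear(1, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

for epoch in range(20):
    for xb, yb in loader:          # one batch at a time
        pred = model(xb)           # 1. forward
        loss = loss_fn(pred, yb)   # 2. loss
        loss.backward()            # 3. backward: fills p.grad for every param
        opt.step()                 # 4. step: uses p.grad to update p
        opt.zero_grad()            # 5. zero: clear .grad for the next batch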

Epochs, batches, and the train/val split

Two structural details turn the bare loop into real training:

  • Batches and epochs. You rarely feed all data at once. You split it into batches (say 32 examples), run the five steps per batch, and one full pass over the dataset is one epoch. You train for many epochs. Batch size and learning rate interact in ways worth their own lesson.
  • Train vs validation. You train on one split and watch a held-out validation split to catch overfitting. Two switches matter here:
model.train()                  # dropout/BatchNorm in TRAINING mode
# ... training loop ...

model.eval()                   # dropout off, BatchNorm uses running stats
with torch.no_grad():          # don't build the autograd graph — faster, less memory
    val_loss = loss_fn(model(x_val), y_val)
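Putting both details together, one epoch is a training pass over every batch followed by a validation pass. The sketch below assumes PyTorch and a toy dataset on y = 2x + 1; the 80/20 split, batch size, learning rate, and epoch count are illustrative, and names such as `train_loader` and `val_loader` are introduced here rather than taken from the lesson.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# Toy dataset (assumed for this sketch): 100 points on y = 2x + 1
x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = 2 * x + 1
train_set, val_set = random_split(TensorDataset(x, y), [80, 20])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

model = nn.Linear(1, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# 80 training examples in batches of 32 gives 3 batches per epoch (32, 32, 16)
batches_per_epoch = len(train_loader)

for epoch in range(50):
    model.train()                         # training mode
    for xb, yb in train_loader:           # the five steps, once per batch
        loss = loss_fn(model(xb), yb)
        loss.backward()
        opt.step()
        opt.zero_grad()

    model.eval()                          # evaluation mode
    with torch.no_grad():                 # no autograd graph during validation
        total = 0.0
        for xb, yb in val_loader:
            # weight each batch's mean loss by its size before averaging
            total += loss_fn(model(xb), yb).item() * len(xb)
        val_loss = total / len(val_set)

print(batches_per_epoch, val_loss)
```

A plain nn.Linear has no dropout or BatchNorm, so train() and eval() change nothing for this particular model. Calling them anyway is a cheap habit that keeps the loop correct when such layers are added later.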

Quick check

Q1. What does optimizer.zero_grad() do, and why is it needed?
Q2. What is the correct order of the five training-loop steps?
Q3. Why wrap validation in with torch.no_grad() and call model.eval()?

Next

You now have the spine every other lesson hangs on. Next we open up step 3 — backprop by hand — to see exactly how backward() computes those gradients, then study the choices that make the loop converge: weight initialization, optimizers, and learning-rate schedules.
