How does transfer learning work for computer vision tasks?
Transfer learning reuses a network pre-trained on a large dataset (typically ImageNet) as a feature extractor or as the starting point for a new task. Early layers learn general features (edges, textures) that transfer well across domains. Later layers encode patterns specific to the source task, so they are fine-tuned or replaced. This can sharply reduce the labelled data and compute needed for the target task.
How to think about it
Describe the spectrum from a frozen backbone to full fine-tuning, give concrete guidance on when to use each option, and mention learning-rate scheduling. Together, these cover the full picture.
Why it works
A model trained on 1.2 million ImageNet images has learned a rich hierarchy of visual features. The first few layers detect low-level structure (oriented edges, colour blobs) that appears in virtually any natural image. Transferring these features means you don’t need to relearn basic vision from scratch on your small target dataset.
The transfer spectrum
Feature extraction (frozen backbone)
Freeze all pre-trained weights. Replace the final classification head with a new head for your N classes, then train only that head.
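- Best when: the target dataset is small (fewer than a few thousand images) and its domain is similar to ImageNet's
- Risk: low; the pre-trained features cannot be overwritten, although a frozen backbone may underfit if the target domain is very different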
Partial fine-tuning
Freeze early layers; unfreeze the last few conv blocks plus the head. Train the unfrozen portion with a small learning rate.
- Best when: a moderate amount of data is available and the domain differs somewhat from ImageNet (e.g., medical images, satellite imagery)
Full fine-tuning
Unfreeze the entire network. Use a much smaller learning rate for the pre-trained layers than for the new head. A learning rate 10× smaller is a common rule of thumb.
- Best when: large target dataset and/or very different domain
Practical steps in PyTorch
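```python
import torch
import torchvision.models as models

# Load ResNet-50 with ImageNet weights (downloaded on first use)
model = models.resnet50(weights="IMAGENET1K_V2")

# Freeze the backbone: no gradients for the pre-trained weights
for param in model.parameters():
    param.requires_grad = False

# Replace the head for a 10-class task; a newly created layer is trainable by default
model.fc = torch.nn.Linear(model.fc.in_features, 10)

# Only the head's parameters go into the optimiser
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
```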
For partial fine-tuning, selectively unfreeze model.layer4 and add its parameters to the optimiser with lr=1e-4.
Domain gap matters
| Source → Target | Recommended strategy |
|---|---|
| ImageNet → general photos | Frozen backbone or light fine-tuning |
| ImageNet → medical CT | Fine-tune last 2 blocks, small LR |
| ImageNet → satellite | Full fine-tuning with warmup |
| ImageNet → ImageNet subset | Head-only or light fine-tuning |
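The satellite row calls for warmup. One common way to implement it is a linear warmup followed by cosine decay. Note that the cosine decay is an assumption here; the text only specifies warmup. The sketch below expresses the schedule with `torch.optim.lr_scheduler.LambdaLR`. Because `LambdaLR` multiplies each parameter group's initial learning rate by the same factor, the backbone/head ratio is preserved throughout. `warmup_cosine`, `warmup_steps` and `total_steps` are hypothetical names.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def warmup_cosine(optimizer: torch.optim.Optimizer, warmup_steps: int,
                  total_steps: int) -> LambdaLR:
    def lr_factor(step: int) -> float:
        # Linear warmup: factor rises from 1/warmup_steps to 1.0.
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        # Cosine decay from 1.0 down to 0.0 at total_steps, then held at 0.
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    # The same factor scales every parameter group's initial learning rate.
    return LambdaLR(optimizer, lr_lambda=lr_factor)
```

Call `scheduler.step()` once after each `optimizer.step()`, so that `warmup_steps` and `total_steps` count optimizer updates rather than epochs.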