Transfer Learning
Transfer learning uses a pretrained model’s knowledge (features, weights) for a new but related task, instead of training from scratch.
Why Use It?
- Faster convergence and less training data needed.
- Better performance when the target dataset is small.
- Leverages features learned from large datasets (e.g., ImageNet).
Two Main Strategies
1. Feature Extraction
- Freeze backbone weights → only train new head.
- Keeps pretrained features intact.
```python
import torch

for param in model.parameters():   # freeze the pretrained backbone
    param.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)  # new head
```
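Since frozen parameters receive no gradients, a common companion step is to hand the optimizer only the trainable parameters. A minimal sketch, assuming a torchvision ResNet-18 backbone and a hypothetical num_classes:

```python
import torch
import torchvision.models as models

num_classes = 10  # hypothetical class count
model = models.resnet18(weights='IMAGENET1K_V1')

for param in model.parameters():   # freeze the backbone
    param.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)  # new head

# Only the new head has requires_grad=True, so optimize just those parameters.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
```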
2. Fine-Tuning
- Unfreeze some/all layers → train with a low LR.
- Allows the pretrained features to adapt to the new task.
```python
import torch

for param in model.parameters():   # unfreeze everything
    param.requires_grad = True
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # low LR for stability
```
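A common refinement when fine-tuning is to give the pretrained backbone a smaller learning rate than the fresh head via optimizer parameter groups. A sketch, assuming a ResNet-18 and illustrative LR values:

```python
import torch
import torchvision.models as models

model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = torch.nn.Linear(model.fc.in_features, 10)  # hypothetical 10-class head

head_params = list(model.fc.parameters())
head_ids = {id(p) for p in head_params}
backbone_params = [p for p in model.parameters() if id(p) not in head_ids]

optimizer = torch.optim.Adam([
    {'params': backbone_params, 'lr': 1e-5},  # gentle updates to pretrained weights
    {'params': head_params, 'lr': 1e-3},      # faster learning for the new head
])
```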
Common Operations on nn.Module
Replace Last Layer
```python
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)  # ResNet example
model.classifier[6] = torch.nn.Linear(4096, num_classes)       # VGG example
```
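After swapping the head, a quick sanity check is to forward a dummy batch and confirm the output shape. A sketch, assuming a ResNet-18 and a hypothetical 5-class head:

```python
import torch
import torchvision.models as models

model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = torch.nn.Linear(model.fc.in_features, 5)  # hypothetical 5-class head

dummy = torch.randn(2, 3, 224, 224)  # batch of 2 ImageNet-sized RGB images
print(model(dummy).shape)            # torch.Size([2, 5])
```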
Remove a Layer
```python
from torch import nn

model.fc = nn.Identity()  # acts like a no-op, passing inputs through unchanged
```
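On a ResNet-18, for example, this turns the forward pass into a pooled-feature extractor (512-dimensional for this variant). A minimal sketch:

```python
import torch
from torch import nn
import torchvision.models as models

model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Identity()  # the network now returns the pooled features

model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 512])
```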
Slice Sequential Models
- If your model is an `nn.Sequential`, you can cut it easily:

```python
features = nn.Sequential(*list(model.children())[:-1])
```
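Note that for a ResNet-18 this slice ends at the average-pooling layer, so the output keeps trailing spatial dimensions and needs flattening. A sketch:

```python
import torch
from torch import nn
import torchvision.models as models

model = models.resnet18(weights='IMAGENET1K_V1')
features = nn.Sequential(*list(model.children())[:-1])  # drop the final fc

with torch.no_grad():
    out = features(torch.randn(1, 3, 224, 224))
print(out.shape)              # torch.Size([1, 512, 1, 1])
flat = torch.flatten(out, 1)  # shape (1, 512), ready for a linear layer
```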
Inspect Model Parts
- If your model is complex, you first need to find the name of the layer you want to replace:
```python
for name, module in model.named_children():
    print(name, module)
```
From the output, you can decide which layers to replace. Generally, you replace the last layer with a new one whose output size matches your task.
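As a concrete illustration, on a torchvision ResNet-18 the last named child is fc, and you can swap it by name. A sketch with a hypothetical 10-class head:

```python
import torch
import torchvision.models as models

model = models.resnet18(weights='IMAGENET1K_V1')

# The last printed child of resnet18 is 'fc', so that's the layer to swap.
last_name, last_module = list(model.named_children())[-1]
print(last_name, last_module)  # fc Linear(in_features=512, out_features=1000, bias=True)

setattr(model, last_name, torch.nn.Linear(last_module.in_features, 10))
```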
Freezing & Unfreezing Layers
- Freeze → set `requires_grad=False` (no weight updates, no gradient memory).
- Unfreeze → set `requires_grad=True`.
- You can freeze the entire model or specific submodules.
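For example, here is a sketch that freezes everything in a torchvision ResNet-18 except the last residual stage and the classifier head:

```python
import torchvision.models as models

model = models.resnet18(weights='IMAGENET1K_V1')

for param in model.parameters():          # freeze everything first
    param.requires_grad = False
for param in model.layer4.parameters():   # ...then unfreeze the last stage
    param.requires_grad = True
for param in model.fc.parameters():       # ...and the classifier head
    param.requires_grad = True
```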
`model.eval()` and `torch.no_grad()` in Transfer Learning
`model.eval()`
- Sets layers like Dropout and BatchNorm to inference mode (no randomness, uses stored stats).
- Does not disable gradient calculation.
`torch.no_grad()`
- Disables gradient tracking (saves memory, speeds up inference).
- Does not change layer behavior.
Typical inference pattern:
```python
model.eval()
with torch.no_grad():
    outputs = model(inputs)
```
Use during:
- Validation
- Testing
- Feature extraction (when you’re not updating weights)
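Putting the pattern together for batched feature extraction, here is a sketch where random tensors stand in for a real, normalized image dataset:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import torchvision.models as models

model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Identity()   # return pooled features instead of logits
model.eval()

loader = DataLoader(TensorDataset(torch.randn(8, 3, 224, 224)), batch_size=4)

chunks = []
with torch.no_grad():
    for (batch,) in loader:         # each item is a 1-tuple of images
        chunks.append(model(batch))
features = torch.cat(chunks)        # shape (8, 512)
```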
Good Practices
- Lower the LR when fine-tuning pretrained layers (`~1e-4` or smaller).
- Normalize inputs with the same mean/std as the pretrained model.
- Freeze/unfreeze in stages for stable training.
- Always call `model.eval()` together with `torch.no_grad()` for inference/feature extraction.
- Track the LR and `requires_grad` status: `[p.requires_grad for p in model.parameters()]`.
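A compact way to audit what will actually train is to count trainable versus total parameters. A sketch on a frozen ResNet-18 with a hypothetical new head:

```python
import torch
import torchvision.models as models

model = models.resnet18(weights='IMAGENET1K_V1')
for param in model.parameters():
    param.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, 10)  # hypothetical head

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / total: {total:,}")
```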
Example: ResNet Feature Extraction
```python
import torch
import torchvision.models as models

resnet = models.resnet18(weights='IMAGENET1K_V1')

# Freeze all pretrained layers
for param in resnet.parameters():
    param.requires_grad = False

# Replace the classifier head (num_classes: your task's class count)
resnet.fc = torch.nn.Linear(resnet.fc.in_features, num_classes)

# Inference (inputs: a normalized image batch of shape (N, 3, 224, 224))
resnet.eval()
with torch.no_grad():
    outputs = resnet(inputs)  # class logits from the new head
```