Transfer Learning
Tags: CS 330, Pretraining
Transfer learning
Multi-task learning is when you want to solve multiple tasks at once; transfer learning is when you want to solve a target task by first solving a source task. Usually, you can't access the source data anymore once you train on the target task.
Transfer learning can also be used for multi-task learning: train a central model on the source data, then transfer this knowledge to each task separately (a different fine-tuned model per task), and together these models cover the multi-task setting.
Why do we transfer-learn?
- when the source dataset is very large and you can easily download the trained weights
- when you don’t care about solving the source and target tasks at the same time
Typical strategies
The typical workflow is as follows:
- Initialize a model trained on the source dataset
- Update the model using the target dataset
A common approach is to reinitialize the last few layers. A key problem is that when gradients pass backward through those freshly reinitialized layers, they are effectively multiplied by random weights, producing large, noisy updates that can destroy the pretrained features.
To solve this, you might use a smaller learning rate or freeze earlier layers.
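A minimal PyTorch sketch of this recipe. The backbone (torchvision ResNet-18 with ImageNet weights), the 10 target classes, and the learning rates are illustrative assumptions, not values from the lecture:

```python
import torch
import torch.nn as nn
from torchvision import models

# Initialize from a model trained on the source dataset (ImageNet weights here).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Reinitialize the last layer for the target task (assume 10 target classes).
model.fc = nn.Linear(model.fc.in_features, 10)

# Mitigation 1: freeze the earlier layers so the randomly initialized head
# cannot push destructive gradients into the pretrained features.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("fc")

# Mitigation 2 (alternative to freezing): keep everything trainable but give
# the pretrained backbone a much smaller learning rate than the new head.
for param in model.parameters():
    param.requires_grad = True
backbone_params = [p for n, p in model.named_parameters() if not n.startswith("fc")]
optimizer = torch.optim.SGD(
    [
        {"params": backbone_params, "lr": 1e-4},        # small LR for pretrained layers
        {"params": model.fc.parameters(), "lr": 1e-2},  # larger LR for the new head
    ],
    momentum=0.9,
)
```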
Pretraining might not need diverse data, and fine-tuning only the last layer is not necessarily the best choice in all cases. If the target task differs in low-level features, you might instead want to fine-tune the first few layers, and so on.
General rule: first train the last layer with frozen features, then fine-tune the entire network.
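A sketch of this two-phase rule under the same assumptions as above (torchvision ResNet-18, 10 hypothetical target classes); the random stand-in for the target DataLoader, the epoch counts, and the learning rates are placeholders just to make it runnable:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models

# Stand-in for the target dataset: random images/labels so the sketch runs end to end.
train_loader = DataLoader(
    TensorDataset(torch.randn(32, 3, 224, 224), torch.randint(0, 10, (32,))),
    batch_size=8,
)

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 10)
loss_fn = nn.CrossEntropyLoss()

def run_epochs(model, optimizer, epochs):
    for _ in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()

# Phase 1: train only the new last layer on top of frozen pretrained features.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("fc")
run_epochs(model, torch.optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9), epochs=5)

# Phase 2: unfreeze everything and fine-tune the whole network with a smaller
# learning rate, now that the head is no longer random.
for param in model.parameters():
    param.requires_grad = True
run_epochs(model, torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9), epochs=5)
```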