What is Transfer Learning? A Guide for Beginners

Suppose you have a problem you want to solve with computer vision but only a few images on which to base your new model. What can you do? You could wait to collect more data, but this may be untenable if the things you want to capture are hard to find (e.g., rare animals in the wild or product defects).

This is where transfer learning comes in. In this article, we're going to discuss:

  1. What is transfer learning?
  2. How does transfer learning work?
  3. How to apply transfer learning
  4. When you should (and shouldn't) use transfer learning

Let's get started!

What is transfer learning?

Transfer learning is a machine learning technique, widely used in computer vision, in which a new model is built upon an existing model. The new model reuses the features the old one has already learned, so it can be trained for its purpose faster and with less data.

The name "transfer learning" describes the technique well: you transfer the knowledge one model has acquired to a new model that can benefit from it. This is similar to how you might carry your knowledge of painting (color theory, your aesthetic) over to drawing, even though the two tasks are different.

Let's talk through an example of transfer learning in action.

An Example of Transfer Learning

Imagine you have images of animals collected on a safari in Africa. The dataset consists of images of giraffes and elephants. Now suppose you want to build a model that distinguishes giraffes from elephants in those images, for example as one component of a system that counts wildlife in a particular area.

The first thing that might come to mind is to build an image recognition model for this from scratch. Unfortunately, you only took a few photos, so such a model is unlikely to achieve high accuracy.

So, you could look for new images to enlarge your dataset, label them, and then train a model from scratch. This is doable. However, even if you find new images suitable for your domain, collecting and labeling them is very time-consuming.

But suppose you have a model that has already been trained on, for example, millions of images to distinguish dogs from cats.

What we can do is take this already-trained model and leverage what it has already learned to teach it to distinguish other kinds of animals (giraffes and elephants, in our case). This spares us from training a model from scratch, which would require not only a large amount of data, which in this case we do not have, but also a lot of computation.

Exploiting the knowledge of an already trained model to create a new model specialized on another task is transfer learning.

How Transfer Learning Works

How is it possible that a model that can recognize cats and dogs can be used to recognize giraffes and elephants? Great question.

Convolutional neural networks (CNNs) extract features from images in a hierarchy. The first layers learn to recognize general, low-level features such as vertical and horizontal edges; subsequent layers combine these into corners, circles, and other simple shapes, and deeper layers combine those into increasingly complex patterns.

These low-level features are largely independent of the type of entity we need to recognize. Computer vision models don't simply "learn" what, for example, a cat looks like as a whole. Instead, they break images down into small components and learn how those components combine into features associated with a particular concept.
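To make the idea of low-level features concrete, here is a minimal plain-Python sketch of what a single convolutional filter computes. The 3x3 kernels are hand-picked Sobel-style edge detectors, an assumption for illustration; in a real CNN the first-layer filters are learned during training rather than written by hand.

```python
def convolve2d(image, kernel):
    """Valid 2D cross-correlation, the operation CNN layers actually compute."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            total = 0
            for di in range(kh):
                for dj in range(kw):
                    total += image[i + di][j + dj] * kernel[di][dj]
            row.append(total)
        out.append(row)
    return out

# Hand-picked vertical-edge kernel: responds to left-to-right intensity changes.
VERTICAL_EDGE = [
    [-1, 0, 1],
    [-2, 0, 2],
    [-1, 0, 1],
]
# Its transpose responds to top-to-bottom changes (horizontal edges).
HORIZONTAL_EDGE = [list(col) for col in zip(*VERTICAL_EDGE)]

# Toy 5x6 image: dark left half (0), bright right half (1), i.e. one vertical edge.
image = [[0, 0, 0, 1, 1, 1] for _ in range(5)]

vertical_response = convolve2d(image, VERTICAL_EDGE)
horizontal_response = convolve2d(image, HORIZONTAL_EDGE)

print(vertical_response)    # [[0, 4, 4, 0], [0, 4, 4, 0], [0, 4, 4, 0]]
print(horizontal_response)  # [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
```

The vertical-edge filter fires only along the boundary, and the horizontal one stays silent. Nothing in either filter is specific to cats, dogs, giraffes, or elephants, which is why features like these carry over well from one task to another.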

Recognition of the entity (animals, in this case) happens in the final linear (fully connected) layers. These take as input the features extracted by the convolutional layers and learn to map them to the final classes (giraffe or elephant).

To apply transfer learning, we remove the linear layers of the already-trained model (since they were trained to recognize other classes) and add new ones. We then train the new layers so that they specialize in recognizing our classes of interest.
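To see what the final linear layers do, here is a minimal plain-Python sketch of a classifier head: one linear layer that turns a feature vector into one score per class, followed by a softmax that turns the scores into probabilities. The feature values and weights are made up for illustration (hypothetical); in a real model the features come from the convolutional layers and the weights are learned.

```python
import math

def linear(x, weights, biases):
    """Fully connected layer: one weighted sum plus bias per output unit."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]

def softmax(scores):
    """Turn raw scores into probabilities that sum to 1 (shifted for stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical output of the convolutional layers for one image.
features = [0.9, 0.1, 0.4]

# Hypothetical head for the new task: one row of weights per class.
classes = ["giraffe", "elephant"]
head_weights = [[2.0, -1.0, 0.5], [-1.0, 2.0, 0.5]]
head_biases = [0.0, 0.0]

probs = softmax(linear(features, head_weights, head_biases))
prediction = classes[probs.index(max(probs))]
print(prediction)  # giraffe
```

In transfer learning, whatever produces `features` is kept from the pre-trained model, while the head (here `head_weights`, `head_biases`, and the class list) is replaced with a new one sized for the new classes and then trained.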

How to Apply Transfer Learning

To apply transfer learning, first choose a model that was trained on a large dataset to solve a similar problem. A common practice is to start from well-known architectures from the computer vision literature, such as VGG, ResNet, and MobileNet, whose pre-trained weights are widely available.

A pre-trained model.

Next, remove the old classifier and output layer.

A pre-trained model without the classifier and output layer.

Next, add a new classifier. This adapts the architecture to the new task. Usually, it means adding a new, randomly initialized linear layer (represented by the blue block below) followed by an output layer with as many units as there are classes in your dataset (represented by the pink block).

A pre-trained model with a new classifier and new output layer.
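The remove-and-replace steps can be sketched in plain Python. Here a model is a dictionary with a `features` part and a `classifier` part; this representation, the helper names (`linear_layer`, `replace_classifier`), the placeholder weights, and the layer sizes (including the assumption that the extractor outputs 8-dimensional feature vectors) are all hypothetical. The new classifier mirrors the figure: a randomly initialized hidden layer followed by an output layer with one unit per class.

```python
import random

def linear_layer(in_features, out_features, rng):
    """Create a randomly initialized fully connected layer."""
    scale = 1.0 / in_features ** 0.5  # keep initial outputs small
    weights = [[rng.uniform(-scale, scale) for _ in range(in_features)]
               for _ in range(out_features)]
    return {"weights": weights, "biases": [0.0] * out_features}

def replace_classifier(pretrained, num_features, hidden_units, num_classes, seed=0):
    """Keep the pre-trained feature extractor, drop the old classifier, add a new one."""
    rng = random.Random(seed)
    return {
        "features": pretrained["features"],  # reused as-is, not re-initialized
        "classifier": [
            linear_layer(num_features, hidden_units, rng),  # new hidden layer (blue block)
            linear_layer(hidden_units, num_classes, rng),   # new output layer (pink block)
        ],
    }

# Hypothetical pre-trained model: placeholder feature-extractor weights and a
# classifier trained for the old classes (cat, dog).
old_rng = random.Random(42)
pretrained = {
    "features": {"conv1": [[0.2, -0.1], [0.4, 0.3]]},
    "classifier": [linear_layer(8, 2, old_rng)],
}

classes = ["giraffe", "elephant"]
new_model = replace_classifier(pretrained, num_features=8, hidden_units=4,
                               num_classes=len(classes))
```

The old classifier is simply not carried over. In PyTorch with a torchvision ResNet, for instance, the same step is typically done by assigning a new `nn.Linear` (or an `nn.Sequential` of layers) to the model's `fc` attribute.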

Next, freeze the feature extractor layers of the pre-trained model. This is an important step. If you don't freeze them, the large, noisy gradients produced by the randomly initialized classifier flow back into the feature extractor during training and can overwrite much of what it has learned. If this happens, you lose the benefit of the pre-training, and the result may be little better than training the model from scratch.
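Freezing simply means excluding parameters from the update step. Here is a minimal plain-Python sketch; the `trainable` flag, the helper names, the toy parameter values, and the gradients are hypothetical, and real frameworks compute gradients automatically.

```python
def freeze(model, part):
    """Mark every parameter in one part of the model as not trainable."""
    for param in model[part]:
        param["trainable"] = False

def sgd_step(model, grads, lr):
    """Plain gradient descent that skips frozen parameters."""
    for part, params in model.items():
        for param, grad in zip(params, grads[part]):
            if param["trainable"]:
                param["value"] = [v - lr * g for v, g in zip(param["value"], grad)]

# Hypothetical toy model: one pre-trained feature-extractor parameter and one
# freshly initialized classifier parameter.
model = {
    "features": [{"value": [0.5, -1.2], "trainable": True}],
    "classifier": [{"value": [0.0, 0.0], "trainable": True}],
}
freeze(model, "features")

grads = {"features": [[1.0, 1.0]], "classifier": [[1.0, 1.0]]}  # made-up gradients
sgd_step(model, grads, lr=0.1)

print(model["features"][0]["value"])    # [0.5, -1.2]  (unchanged)
print(model["classifier"][0]["value"])  # [-0.1, -0.1]
```

In PyTorch, the equivalent is setting `requires_grad = False` on the feature extractor's parameters; in Keras, it is setting a layer's `trainable` attribute to `False`.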

The final step is to train the new layers. You only need to train the new classifier on the new dataset.
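Training only the new classifier can be sketched as follows, assuming the frozen feature extractor has already turned each image into a short feature vector. The feature vectors and labels are made-up toy numbers (hypothetical), and a single logistic-regression output stands in for the classifier, since giraffe vs. elephant is a two-class problem.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def train_head(features, labels, lr=0.5, epochs=200):
    """Fit a logistic-regression head by gradient descent on fixed (frozen) features."""
    n_features = len(features[0])
    weights = [0.0] * n_features
    bias = 0.0
    for _ in range(epochs):
        grad_w = [0.0] * n_features
        grad_b = 0.0
        for x, y in zip(features, labels):
            error = sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias) - y
            for k in range(n_features):
                grad_w[k] += error * x[k]
            grad_b += error
        # Average the gradients over the dataset, then take one step.
        weights = [w - lr * g / len(features) for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / len(features)
    return weights, bias

def predict(weights, bias, x):
    """Return 1 (elephant) or 0 (giraffe)."""
    return 1 if sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias) >= 0.5 else 0

# Hypothetical features from the frozen extractor: 1 = elephant, 0 = giraffe.
features = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.1, 0.9], [0.2, 0.8], [0.3, 0.7]]
labels = [1, 1, 1, 0, 0, 0]

weights, bias = train_head(features, labels)
accuracy = sum(predict(weights, bias, x) == y for x, y in zip(features, labels)) / len(labels)
print(f"training accuracy: {accuracy:.2f}")  # training accuracy: 1.00
```

Because the feature extractor is frozen, its outputs never change, so in practice the features can be computed once and reused every epoch. This is part of why training only the new classifier is fast.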

Once the new classifier is trained, you will have a model that can make predictions on your dataset. Optionally, you can improve its performance through fine-tuning. Fine-tuning consists of unfreezing parts of the pre-trained model and continuing to train it on the new dataset, so that the pre-trained features adapt to the new data. To avoid overfitting, fine-tune only if the new dataset is large enough, and use a lower learning rate.
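Fine-tuning can be sketched in the same plain-Python style: unfreeze part of the pre-trained feature extractor (usually the later layers, since the early layers hold the most generic features) and keep training with a learning rate well below the one used for the classifier. The layer names, values, gradients, and both learning rates below are hypothetical.

```python
HEAD_TRAINING_LR = 0.1  # hypothetical rate used while only the classifier was trained
FINE_TUNE_LR = 0.001    # hypothetical, much lower rate for fine-tuning

def unfreeze(param_groups, names):
    """Make the named parts of the pre-trained model trainable again."""
    for group in param_groups:
        if group["name"] in names:
            group["trainable"] = True

def sgd_step(param_groups, grads, lr):
    """Gradient descent over all trainable groups; frozen groups are skipped."""
    for group in param_groups:
        if group["trainable"]:
            group["values"] = [
                v - lr * g for v, g in zip(group["values"], grads[group["name"]])
            ]

# Hypothetical model after the classifier has been trained: both convolutional
# groups are still frozen, the classifier is trainable.
param_groups = [
    {"name": "early_conv", "values": [0.5, -0.3], "trainable": False},
    {"name": "late_conv", "values": [1.0, 2.0], "trainable": False},
    {"name": "classifier", "values": [0.2, -0.4], "trainable": True},
]

# Unfreeze only the last convolutional group; early, generic features stay frozen.
unfreeze(param_groups, {"late_conv"})

grads = {name: [1.0, 1.0] for name in ("early_conv", "late_conv", "classifier")}  # made-up
sgd_step(param_groups, grads, lr=FINE_TUNE_LR)
```

The small learning rate nudges the pre-trained weights toward the new data instead of overwriting them, which is also what helps keep overfitting in check.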

When to Use Transfer Learning

Let's discuss the scenarios in which transfer learning is worth using, as well as those in which it is not.

It is worth using transfer learning when you have:

  • A low quantity of data: Training from scratch on too little data results in poor model performance. Starting from a pre-trained model typically produces a more accurate model, and it takes less time to get a model up and running because you don't need to spend time collecting more data.
  • A limited amount of time: Training a machine learning model from scratch can take a long time. When you don't have much time (for example, when creating a prototype to validate an idea), transfer learning is worth considering.
  • Limited computation capabilities: Training a machine learning model on millions of images requires a lot of computation. With a pre-trained model, someone has already done that hard work for you, giving you a good set of weights to start from. This reduces the amount of computation, and therefore hardware, required to train your model.

When You Shouldn't Use Transfer Learning

On the other hand, transfer learning is not appropriate when:

  • There is a mismatch in domain: most of the time, transfer learning does not work well if the data the pre-trained model was trained on is very different from the data you want to use it on. The two datasets need to be similar in what they contain and predict (e.g., training a defect classifier starting from a model trained on images of similar products annotated with scratches and dents).
  • You have a large dataset: when plenty of data is available, the advantage of transfer learning shrinks, and it may not have the expected effect. A fine-tuned model starts from features tuned to the original dataset, and training may settle near that starting point (for example, in a local minimum) instead of learning the features that matter most for the new data. If you have a large dataset, consider training the model from scratch so it can learn the key features directly from your data.

Key Takeaways On Transfer Learning

Transfer learning focuses on storing knowledge gained while solving one problem and applying it to a different but related problem. Instead of training a neural network from scratch, you can use one of many pre-trained models as the starting point for training. These pre-trained models provide a proven architecture and a good set of initial weights, saving time and resources.

You may want to consider using transfer learning when you have a limited amount of data, limited time, or limited computation capabilities.

You should not use transfer learning when the data you have is very different from the data the pre-trained model was trained on, or when you have a large dataset. In these two cases, training a model from scratch is likely to be the better choice.

Now you have the information you need to understand the basics of transfer learning and when it is and is not useful. Happy model building!