Can a computer tell the difference between a dandelion and a daisy? In this post we put these philosophical musings aside and dive into the code necessary to find the answer.
We walk through the steps necessary to train a custom image classification model from the Resnet34 backbone using the fastai library and all its underlying PyTorch operations. At the end, you will have a model that can distinguish between your custom classes.
Resources included in this tutorial:
- Public Flower Classification dataset
- How to Train a Custom Resnet34 Model for Image Classification Colab Notebook
- Custom Image Classification YouTube
How The Resnet Model Works
Resnet is a convolutional neural network that can be used as a state-of-the-art image classification model. The Resnet models we use in this tutorial have been pre-trained on ImageNet, a large image classification dataset.
The ImageNet classification dataset used for pre-training (ILSVRC 2012) contains roughly 1.28 million training images across 1,000 classes; even the downsampled Tiny ImageNet variant has 100,000 training images across 200 classes. With that many classes, there is a good chance that images similar to yours were seen during pre-training.
To stand on the shoulders of giants, we start from the pre-trained checkpoint and fine-tune our Resnet model from this base state, a process often called "transfer learning".
For more info on Resnet, I recommend checking out the paper, "Deep Residual Learning for Image Recognition" (He et al., 2015).
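The central idea of that paper is the residual connection: instead of learning a mapping H(x) directly, each block learns a residual F(x) and outputs F(x) + x, which makes very deep networks much easier to train. Resnet34 stacks 16 "basic" blocks like the one sketched below (3, 4, 6, and 3 per stage). This is a simplified PyTorch sketch modeled on that design, not torchvision's exact source; the class name `BasicBlock` is just for illustration.

```python
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    """Two 3x3 convolutions plus a shortcut: output = ReLU(F(x) + shortcut(x))."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Identity shortcut when shapes match; otherwise a 1x1 projection
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # The block only has to learn F(x), the difference from its input
        return self.relu(residual + self.shortcut(x))


block = BasicBlock(64, 128, stride=2)
print(block(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])
```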
Data Preparation to Train a Custom Resnet34 Model
In this tutorial, we will use Roboflow as the dataset source of record, organization, preprocessor, and augmenter.
If you don't already have your own classification dataset, feel free to use the Public Flower Classification dataset to follow along with the tutorial.
Assuming you have your own dataset, the first step is to upload your data to Roboflow. To load a classification dataset, sort your images into separate folders named after their classes.
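Before uploading, it can help to sanity-check that layout and spot badly imbalanced classes. A minimal standard-library sketch; the `flowers/daisy/dandelion` names are illustrative, the extension list is an assumption, and `count_images_per_class` is a hypothetical helper, not part of Roboflow.

```python
from pathlib import Path

# Expected layout before upload (folder and file names are illustrative):
#   flowers/
#     daisy/       img_001.jpg, img_002.jpg, ...
#     dandelion/   img_101.jpg, img_102.jpg, ...
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(root: str) -> dict:
    """Map each class folder directly under root to the number of images it holds."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # stray files at the top level are not classes
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts


# Usage: print(count_images_per_class("flowers"))
```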
Then, sign up for a free account at roboflow.com and hit Create New Dataset. There you can simply drag and drop your image classification dataset into the Roboflow platform.
Next, we can go ahead and choose Augmentation settings in the Roboflow platform to create a dataset version of our original training data.
Preprocessing standardizes the dataset across train, validation, and test splits. Augmentation creates new images from the base training set to help prevent your model from overfitting.
For our dataset, I have created an augmented dataset version that includes Cutout augmentations. I have also generated 5 extra images per base training image. This results in a dataset of 6,921 images.
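Cutout masks out a random square patch of a training image, so the model cannot rely on any single region to recognize a class. The NumPy sketch below illustrates the idea from DeVries and Taylor's 2017 Cutout paper; it is not Roboflow's implementation, whose patch count and size settings may differ, and `cutout` is a hypothetical helper.

```python
import numpy as np


def cutout(image: np.ndarray, size: int, rng: np.random.Generator) -> np.ndarray:
    """Return a copy of an HxWxC image with one size x size square set to zero.

    The square's centre can land anywhere in the image, so near the borders
    only part of the square is visible (it is clipped to the image).
    """
    h, w = image.shape[:2]
    cy, cx = int(rng.integers(0, h)), int(rng.integers(0, w))
    y0, x0 = cy - size // 2, cx - size // 2
    # Clip the square to the image bounds
    y1, x1 = min(y0 + size, h), min(x0 + size, w)
    y0, x0 = max(y0, 0), max(x0, 0)
    out = image.copy()
    out[y0:y1, x0:x1] = 0
    return out


rng = np.random.default_rng(0)
base = np.full((64, 64, 3), 255, dtype=np.uint8)
# Five extra training images generated from one base image
augmented = [cutout(base, size=16, rng=rng) for _ in range(5)]
```

Because each call draws a new position, the five copies differ from one another, which is what makes the extra images useful.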
Once you are satisfied with your dataset version, hit Download and then Show Link to receive a curl link that you can bring into the Colab Notebook for dataset import.
Entering the notebook: How to Train a Custom Resnet34 Model for Image Classification Colab Notebook
We recommend having the notebook and blog post open side by side. Open the notebook and hit Save Copy in Drive so that you can edit it with your own code.
Installing Fastai Dependencies to Train Resnet34
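In the Colab Notebook we install the fastai library and import everything from fastai.vision. This provides many of the tools we will need later in training. Note that the code in this tutorial uses the fastai v1 API (create_cnn, open_image, fastai.callbacks), which differs from fastai v2.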
Using Custom Data to Train Resnet34
To export your own data for this tutorial, sign up for Roboflow and make a public workspace, or create a new public workspace in your existing account. If your data is private, you can upgrade to a paid plan to export it for external training routines like this one, or experiment with Roboflow's internal training solution.
When you created a dataset version above, you received a curl link from Show Link. Copy and paste it into the notebook where it reads [YOUR LINK HERE].
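The notebook runs the curl command directly. If you would rather download from Python, here is a minimal standard-library sketch, assuming you extract just the download URL from the link. `download_and_extract`, the `/content` destination, and the temporary `roboflow.zip` file name are illustrative choices.

```python
import shutil
import urllib.request
import zipfile
from pathlib import Path


def download_and_extract(url: str, dest: str = "/content") -> Path:
    """Download a zip archive from url and unpack it into dest."""
    dest_path = Path(dest)
    dest_path.mkdir(parents=True, exist_ok=True)
    archive = dest_path / "roboflow.zip"
    # Stream the download to disk instead of loading it all into memory
    with urllib.request.urlopen(url) as response, open(archive, "wb") as f:
        shutil.copyfileobj(response, f)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_path)
    archive.unlink()  # remove the zip once it has been extracted
    return dest_path


# Usage: download_and_extract("<the URL from Show Link>")
```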
After our dataset has been downloaded, we will load it into the fastai data loader, normalizing it to the mean and standard deviation of the ImageNet dataset.
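Normalization is a per-channel formula: each pixel value is scaled to [0, 1], then the ImageNet channel mean is subtracted and the result divided by the channel standard deviation, x' = (x / 255 - mean_c) / std_c. In fastai v1 this is typically applied with `.normalize(imagenet_stats)`; the NumPy sketch below, with the hypothetical helper `normalize_imagenet`, only makes the arithmetic explicit. The statistics are the standard ImageNet RGB values also used by torchvision.

```python
import numpy as np

# Per-channel (R, G, B) ImageNet statistics for images scaled to [0, 1]
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def normalize_imagenet(image: np.ndarray) -> np.ndarray:
    """Normalize an HxWx3 uint8 RGB image the way ImageNet-pretrained models expect."""
    scaled = image.astype(np.float32) / 255.0
    # Broadcasting applies each channel's mean and std to every pixel
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD
```

Using the pre-training statistics matters because the pre-trained filters were learned on inputs distributed this way.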
We can take a peek at a batch to make sure the data has loaded correctly.
Download a Custom Resnet Image Classification Model
For the next step, we download the pre-trained Resnet model from the torchvision model library.
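```python
# Build a Resnet34 learner from ImageNet-pretrained weights (downloaded on first use)
learn = create_cnn(data, models.resnet34, metrics=error_rate)
```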
In this tutorial we implement Resnet34 for custom image classification, but you can swap in other architectures from the torchvision model library by changing models.resnet34 in the line above. So in that sense, this is also a tutorial on:
- How to train a custom Resnet18 image classification model
- How to train a custom Resnet50 image classification model
- How to train a custom Resnet101 image classification model
- How to train a custom Resnet152 image classification model
- How to train a custom Squeezenet image classification model
- How to train a custom VGG image classification model
If you aren't seeing the performance you need, try using a larger model (like Resnet152 with 152 layers).
Train a Custom Resnet Image Classification Model
After initializing our model, we take a first pass at training by fine-tuning only the head (the final layers fastai adds for our classes) while the pre-trained body of the model stays frozen. This lets the new head learn to map the pre-trained features to our classes before we update the rest of the network.
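```python
from fastai.callbacks import *

# Stop if validation loss has not improved for 20 epochs
early_stop = EarlyStoppingCallback(learn, patience=20)
# Save the best weights (as best_resnet34.pth) whenever validation loss improves
save_best_model = SaveModelCallback(learn, name='best_resnet34')

# frozen training step
defaults.device = torch.device('cuda')  # use the GPU as fastai's default device
learn.fit_one_cycle(50, callbacks=[early_stop, save_best_model])
```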
Here we train for 50 epochs (the first argument to fit_one_cycle); you can increase this to train for longer, although early stopping may end training sooner.
We implement two training callbacks, EarlyStoppingCallback and SaveModelCallback. Early stopping stops training if validation loss has not decreased for 20 epochs. Save Model saves the best model based on validation loss so we can recover it.
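Both callbacks boil down to a few lines of bookkeeping. The framework-free sketch below (the `EarlyStopper` class is hypothetical, not fastai's code) tracks the best validation loss and counts epochs without improvement; fastai's own callbacks may differ in details such as exactly when the patience limit triggers.

```python
import math


class EarlyStopper:
    """Track validation loss; signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 20, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = math.inf
        self.best_epoch = None
        self.epochs_without_improvement = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # New best: this is the point where a save-best callback writes a checkpoint
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopper(patience=3)
for epoch, loss in enumerate([0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.58]):
    if stopper.step(epoch, loss):
        print(f"stopping after epoch {epoch}; best epoch was {stopper.best_epoch}")
        break
# prints: stopping after epoch 5; best epoch was 2
```

Note that training stops even though epoch 6 would have improved; a larger patience, like the 20 used above, trades extra epochs for tolerance to noisy validation loss.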
Finding the fastai Learning Rate
Next, we unfreeze the model parameters and estimate a good learning rate for the rest of training. If the learning rate is too small, the model learns very slowly; if it is too large, the weight updates overshoot and the loss can diverge instead of decreasing.
```python
import numpy as np

def find_appropriate_lr(model: Learner) -> float:
    # Run fastai's learning rate finder (a short run with increasing learning rates)
    model.lr_find()
    # Loss and learning rate recorded at each step of the finder
    losses = np.array([float(loss) for loss in model.recorder.losses])
    lrs = model.recorder.lrs
    min_loss_index = np.argmin(losses)
    # Return the learning rate that produced the minimum loss, divided by 10
    return lrs[min_loss_index] / 10

learn.unfreeze()  # make all layers trainable before searching for a new learning rate
lr = find_appropriate_lr(learn)
```
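fastai's lr_find trains briefly while raising the learning rate exponentially after each mini-batch and records the loss at every step; the heuristic above then takes the rate at the lowest loss and backs off by a factor of 10, since just past the minimum the loss starts to blow up. A NumPy sketch of both halves on a synthetic loss curve; `lr_range` and `pick_lr` are hypothetical helpers, and the sweep of 100 rates from 1e-7 to 10 is our assumption about fastai v1's defaults.

```python
import numpy as np


def lr_range(start_lr: float = 1e-7, end_lr: float = 10.0, num_it: int = 100) -> np.ndarray:
    """Learning rates for a range test, spaced evenly on a log scale."""
    return np.geomspace(start_lr, end_lr, num_it)


def pick_lr(lrs: np.ndarray, losses: np.ndarray) -> float:
    """Same heuristic as find_appropriate_lr: the rate at minimum loss, divided by 10."""
    return float(lrs[np.argmin(losses)] / 10)


# Synthetic loss curve for illustration: improves until about lr=1e-2, then diverges
lrs = lr_range()
losses = (np.log10(lrs) + 2) ** 2
print(pick_lr(lrs, losses))  # about 1e-3
```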
Next, we train the whole unfrozen model for another 50 epochs. This fine-tunes the pre-trained layers along with the new head to squeeze out additional performance. You can watch the validation error rate decrease as training progresses.
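fit_one_cycle schedules the learning rate instead of holding it constant: following Leslie Smith's one-cycle policy, it warms up from a small value to the maximum learning rate, then anneals to far below the starting value. The sketch below computes that curve with cosine interpolation. The defaults (pct_start=0.3, div_factor=25, final_div=25e4) are assumptions modeled on fastai v1, `one_cycle_lr` is a hypothetical helper, and fastai also cycles momentum in the opposite direction, which is left out here.

```python
import math


def cosine_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))


def one_cycle_lr(step: int, total_steps: int, lr_max: float,
                 pct_start: float = 0.3, div_factor: float = 25.0,
                 final_div: float = 25.0 * 1e4) -> float:
    """Learning rate at `step` of a one-cycle schedule (warm up, then anneal)."""
    warmup_steps = int(total_steps * pct_start)
    if step < warmup_steps:
        # Phase 1: rise from lr_max / div_factor up to lr_max
        return cosine_anneal(lr_max / div_factor, lr_max, step / warmup_steps)
    # Phase 2: fall from lr_max to lr_max / final_div by the last step
    pct = (step - warmup_steps) / max(total_steps - warmup_steps - 1, 1)
    return cosine_anneal(lr_max, lr_max / final_div, pct)


schedule = [one_cycle_lr(s, total_steps=100, lr_max=1e-3) for s in range(100)]
print(schedule[0], schedule[30], schedule[-1])
```

The high middle phase lets training move quickly, while the long tail at tiny learning rates settles the weights into a good minimum.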
Using Custom Resnet Image Classification Model for Inference
We can evaluate our model's performance on the validation set. fastai provides a convenient method to visualize your model's confusion matrix.
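Under the hood, a confusion matrix is just a count table: rows are the actual classes, columns are the predicted classes, and the diagonal holds correct predictions. In fastai v1 the plot usually comes from `ClassificationInterpretation.from_learner(learn)` followed by `interp.plot_confusion_matrix()`; the NumPy sketch below, with our own `confusion_matrix` helper and made-up labels, only shows the counting.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes: int) -> np.ndarray:
    """Count predictions: rows are actual classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for actual, predicted in zip(y_true, y_pred):
        cm[actual, predicted] += 1
    return cm


classes = ["daisy", "dandelion"]
y_true = [0, 0, 0, 1, 1, 1]  # made-up labels for illustration
y_pred = [0, 1, 0, 1, 1, 0]
cm = confusion_matrix(y_true, y_pred, len(classes))
print(cm.tolist())  # [[2, 1], [1, 2]]
print("accuracy:", np.trace(cm) / cm.sum())  # 4 of 6 correct
```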
You can also inspect the images your model finds hardest, the ones with the highest loss. Some of these daisies look like dandelions!
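For cross-entropy, an image's loss is minus the log of the probability the model gave its true class, so confident mistakes rank highest. fastai v1 exposes this as `interp.plot_top_losses(k)`; below is a minimal NumPy sketch with a hypothetical `top_losses` helper and made-up probabilities.

```python
import numpy as np


def top_losses(probs: np.ndarray, targets: np.ndarray, k: int):
    """Return (indices, losses) of the k images with the highest cross-entropy loss."""
    # Probability the model assigned to each image's true class
    true_class_probs = probs[np.arange(len(targets)), targets]
    losses = -np.log(true_class_probs)
    worst = np.argsort(-losses)[:k]  # highest loss first
    return worst, losses[worst]


probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # columns: daisy, dandelion
targets = np.array([0, 1, 1])  # made-up true labels
idx, worst_losses = top_losses(probs, targets, k=2)
print(idx.tolist())  # [2, 1]
```

Image 2 is a dandelion the model called a daisy, so it tops the list.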
Lastly, we can run inference on our test set, images our model has never seen.
```python
# run inference on test images
import glob
from IPython.display import Image, display

learn.model.cuda()  # make sure the model is on the GPU

for imageName in glob.glob('/content/test/*/*.jpg'):
    print(imageName)
    img = open_image(imageName)
    # predict returns (predicted category, class index, class probabilities)
    prediction = learn.predict(img)
    print(prediction)
    display(Image(filename=imageName))
    print("\n")
```
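Because the test images live in class-named folders (the glob matches /content/test/<class>/*.jpg), each image's true label is its parent folder name, so the printed predictions can be turned into an accuracy score. A minimal sketch, assuming you collect predictions into a dict keyed by path and that `str(prediction[0])` gives the class name, as we expect for fastai v1 Category objects; `folder_accuracy` is a hypothetical helper.

```python
from pathlib import Path


def folder_accuracy(predictions: dict) -> float:
    """Fraction of images whose predicted class matches their parent folder name."""
    if not predictions:
        raise ValueError("no predictions to score")
    correct = sum(Path(path).parent.name == label for path, label in predictions.items())
    return correct / len(predictions)


# Inside the inference loop you could record, for example:
#   predictions[imageName] = str(prediction[0])
predictions = {
    "/content/test/daisy/img1.jpg": "daisy",
    "/content/test/daisy/img2.jpg": "dandelion",
    "/content/test/dandelion/img3.jpg": "dandelion",
    "/content/test/dandelion/img4.jpg": "dandelion",
}
print(folder_accuracy(predictions))  # 0.75
```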
Saving Custom Resnet Image Classification Weights
For the last step of the notebook, we provide code to export your model weights for future use. The weights are saved as .pth PyTorch weights, which can be loaded with the same fastai library or plain PyTorch, and converted to TorchScript or ONNX for deployment.
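Keep in mind that a .pth file holds only weights: to reload it you must first rebuild the same architecture, and TorchScript or ONNX need a separate conversion step. A minimal plain-PyTorch sketch that does both a state_dict save and a TorchScript trace; `export_model` and the file names are hypothetical, the 224x224 input size is an assumption, and ONNX export (`torch.onnx.export`) follows the same pattern but is not shown.

```python
import torch
import torch.nn as nn


def export_model(model: nn.Module, weights_path: str, torchscript_path: str,
                 input_size=(1, 3, 224, 224)) -> None:
    """Save a model's weights as .pth and a traced TorchScript copy."""
    model.eval()  # inference behaviour for dropout/batch norm (modifies model in place)
    # .pth file: a state_dict (tensors only); reloading needs the same architecture
    torch.save(model.state_dict(), weights_path)
    # TorchScript: trace the model with an example input on the model's device
    device = next(model.parameters()).device
    example = torch.randn(*input_size, device=device)
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    traced.save(torchscript_path)


# For the fastai learner in this tutorial, for example:
#   export_model(learn.model, "resnet34_weights.pth", "resnet34_traced.pt")
```

The TorchScript file can be loaded with `torch.jit.load` without the original Python class, which is what makes it convenient for deployment.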
The rest of the application is up to you 🚀
Your Trained Resnet34 Image Classification Model
You have now learned how to train a custom Resnet34 image classification model to differentiate between your own image classes. All it takes is the right dataset, dataset management tools, and model architecture.
And better yet, image classification is part of the Roboflow Train one-click integration. Try it for free today.