Can a computer tell the difference between a dandelion and a daisy? In this post we put these philosophical musings aside and dive into the code necessary to find the answer.

We walk through the steps necessary to train a custom image classification model from the Resnet34 backbone using the fastai library and all its underlying PyTorch operations. At the end, you will have a model that can distinguish between your custom classes.

Resources included in this tutorial:

The Resnet Model

Resnet is a convolutional neural network architecture that can be used as a state-of-the-art image classification model. The Resnet models we will use in this tutorial have been pretrained on the ImageNet dataset, a large classification dataset. Tiny ImageNet alone contains over 100,000 images across 200 classes. The large ImageNet dataset contains a vast array of image classes, and there is a good chance that images similar to yours were used in pretraining.

To stand on the shoulders of giants, we will start our model from the pretrained checkpoint and fine-tune our Resnet model from this base state. This process is often called "transfer learning".

The Resnet34 layer architecture on the right (source)

For more info on Resnet, I recommend checking out the paper.

Let's dive in!

Data Preparation in Roboflow

In this tutorial, we will use Roboflow as the dataset source of record, organization, preprocessor, and augmenter.

If you don't already have your own classification dataset, feel free to use the Public Flower Classification dataset to follow along directly with the tutorial.

Assuming you have your own dataset, the first step is to upload your data to Roboflow. To load a classification dataset, sort your images into separate folders, one per class, named after the class.

Example folder structure for classification dataset
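As a quick sanity check before uploading, the sketch below counts the images in each class folder. It assumes the layout described above (one subfolder per class under a root folder). The function name `count_images_per_class`, the folder names, and the list of extensions are illustrative choices, not part of Roboflow or fastai.

```python
import tempfile
from pathlib import Path

# Extensions treated as images; an assumption for this sketch.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}

def count_images_per_class(root):
    """Return {class_name: image_count} for a root folder with one subfolder per class."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts

# Build a tiny throwaway dataset to demonstrate the expected layout.
with tempfile.TemporaryDirectory() as tmp:
    for class_name, n_images in [("daisy", 3), ("dandelion", 2)]:
        class_dir = Path(tmp) / class_name
        class_dir.mkdir()
        for i in range(n_images):
            (class_dir / f"{class_name}_{i}.jpg").touch()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

Pointing the function at your own root folder is an easy way to spot an empty or badly unbalanced class before you upload.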

Then, sign up for a free account at roboflow.com, and hit Create New Dataset. There you can simply drag and drop your image classification dataset into the Roboflow platform. Note: to stay in the free tier, you may want to downsize the dataset to less than 1000 images.

Next, we can go ahead and choose Preprocessing and Augmentation settings in the Roboflow platform to create a dataset version of our original training data. Preprocessing standardizes the dataset across train, validation, and test splits. Augmentation creates new images from the base training set to help prevent your model from overfitting.

Choosing preprocessing and augmentations in Roboflow

For our dataset, I have created an augmented dataset version that includes Crop, Rotation, Brightness, Exposure, and Cutout augmentations. I have also generated 5 extra images per base train set image. This results in a large dataset of 6921 images.

Once you are satisfied with your dataset version, hit Generate then Download and then Show Link to receive a curl link that you can bring into the Colab Notebook for dataset import.

Entering the notebook: How to Train a Custom Resnet34 Model for Image Classification Colab Notebook

We recommend having the notebook and blog post open simultaneously. Open the notebook, then choose File and Save a Copy in Drive. This will allow you to edit it with your own code.

Installing Fastai Dependencies

In the Colab Notebook we will install the fastai library and import everything from fastai.vision. This will provide us with many of the tools we will need later in training.

Downloading Classification Dataset from Roboflow

The next step is to download your classification dataset from Roboflow.

When you created a dataset version above, you received a curl link from Show Link. Copy and paste that into the notebook where it reads [YOUR LINK HERE].

Downloading custom image classification data from Roboflow

After our dataset has been downloaded, we will load it into the fastai data loader, normalizing it to the mean and standard deviation of the ImageNet dataset.
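Normalizing to ImageNet statistics means subtracting a per-channel mean and dividing by a per-channel standard deviation, so that your inputs look like what the pretrained weights saw during pretraining. The sketch below applies this to a single RGB pixel scaled to [0, 1]. The values are the commonly used ImageNet channel statistics, and `normalize_pixel` is a hypothetical helper for illustration, not a fastai function.

```python
# Commonly used ImageNet channel statistics (RGB, pixel values scaled to [0, 1]).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Normalize one RGB pixel: (value - mean) / std for each channel."""
    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))

# A pixel equal to the ImageNet mean maps to zero in every channel.
print(normalize_pixel(IMAGENET_MEAN))

# A pure white pixel lands more than two standard deviations above the mean.
print(normalize_pixel((1.0, 1.0, 1.0)))
```

In the notebook, fastai applies the same operation to whole batches of images for you.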

We can take a peek at our batch to make sure the data has loaded in correctly.

A quick check shows that our data has loaded in correctly

Download a Custom Resnet Image Classification Model

For the next step, we download the pretrained Resnet model from the torchvision model library.

learn = create_cnn(data, models.resnet34, metrics=error_rate)

In this tutorial we implement Resnet34 for custom image classification, but every model in the torchvision model library is fair game. So in that sense, this is also a tutorial on:

  • How to train a custom Resnet18 image classification model
  • How to train a custom Resnet50 image classification model
  • How to train a custom Resnet101 image classification model
  • How to train a custom Resnet152 image classification model
  • How to train a custom Squeezenet image classification model
  • How to train a custom VGG image classification model

If you aren't seeing the performance you need, try using a larger model (like Resnet152 with 152 layers).

printing out the ResNet architecture

Train a Custom Resnet Image Classification Model

Frozen Training

After initializing our model, we take a first pass at training by fine-tuning only the last layer of the model while the rest of the model stays frozen. This gives the new final layer a chance to adapt to the pretrained features before we update the rest of the network.

from fastai.callbacks import *
early_stop = EarlyStoppingCallback(learn, patience=20)
save_best_model = SaveModelCallback(learn, name='best_resnet34')

#frozen training step
defaults.device = torch.device('cuda') # makes sure the gpu is used
learn.fit_one_cycle(50, callbacks=[early_stop, save_best_model])

The notebook trains for 50 epochs here; you can increase this number to train your model for longer.

We implement two training callbacks - EarlyStopping and SaveModel. Early stopping will stop training if validation loss has not decreased for 20 epochs. Save Model will save the best model based on validation loss so we can recover it.
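The interplay of the two callbacks can be sketched in plain Python. This is a simplified illustration of the idea, not fastai's implementation: `simulate_callbacks` is a hypothetical helper, the loss values are made up, and the exact epoch at which fastai stops may differ by one from this sketch.

```python
def simulate_callbacks(valid_losses, patience):
    """Walk through per-epoch validation losses.

    Returns (best_epoch, stop_epoch): the epoch whose model would be saved,
    and the epoch at which training would stop early (None if it never does).
    """
    best_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best_loss:
            # Improvement: remember this epoch (the "save model" step) and reset the counter.
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # No improvement for `patience` epochs in a row: stop early.
                return best_epoch, epoch
    return best_epoch, None

# Made-up validation losses: improvement stalls after epoch 2.
losses = [0.90, 0.60, 0.45, 0.47, 0.46, 0.50, 0.48]
print(simulate_callbacks(losses, patience=3))
```

With a patience of 20, as in the notebook, this short run would never stop early, but the best epoch would still be the one that gets saved.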

Finding the fastai Learning Rate

Next we unfreeze the model parameters and calculate the optimal learning rate going forward. Too small and your model won't learn much. Too big and you may backprop way off the map in the loss function space.

import numpy as np

def find_appropriate_lr(model:Learner, lr_diff:int = 15, loss_threshold:float = .05, adjust_value:float = 1, plot:bool = False) -> float:
    # Run the learning rate finder; its results are stored on model.recorder
    model.lr_find()

    # Losses recorded during the sweep and the learning rate used at each step
    losses = np.array(model.recorder.losses)
    lrs = model.recorder.lrs

    # Index of the step with the lowest recorded loss
    min_loss_index = np.argmin(losses)

    # Return the learning rate that produced the minimum loss, divided by 10
    # (lr_diff, loss_threshold, adjust_value, and plot are currently unused)
    return lrs[min_loss_index] / 10
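To make the selection rule concrete, here is the same "minimum loss, divided by 10" logic applied to made-up numbers rather than a real learning rate sweep. `pick_lr` and the sample values are hypothetical and exist only for illustration.

```python
def pick_lr(lrs, losses):
    """Return the learning rate at the lowest loss, divided by 10."""
    min_loss_index = min(range(len(losses)), key=losses.__getitem__)
    return lrs[min_loss_index] / 10

# Made-up sweep: the loss falls as the learning rate grows, then blows up.
sample_lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
sample_losses = [1.20, 1.05, 0.70, 0.55, 3.40]

chosen = pick_lr(sample_lrs, sample_losses)
print(chosen)  # one tenth of the learning rate at the lowest loss
```

The lowest loss in a sweep like this tends to sit close to the point where training starts to diverge, which is one reason to step back by a factor of 10 instead of using that learning rate directly.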

Unfrozen Training

Next, we train the whole unfrozen model for another 50 epochs. This fine-tunes every layer of the model to squeeze out additional performance. You can watch as the validation error rate decreases.

Using Custom Resnet Image Classification Model for Inference

We can evaluate our model's performance by using it for test inference. Fastai provides a convenient method to visualize your model's confusion matrix.

Our model has done a good job of telling the difference between daisies and dandelions
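For intuition about what the confusion matrix shows, here is a small plain-Python sketch that builds one from made-up labels. fastai computes and plots this for you; `confusion_matrix` below is a hypothetical stand-in, not the library's code.

```python
from collections import Counter

def confusion_matrix(actual, predicted, classes):
    """Rows are actual classes, columns are predicted classes."""
    pair_counts = Counter(zip(actual, predicted))
    return [[pair_counts[(a, p)] for p in classes] for a in classes]

# Made-up labels for six validation images.
classes = ["daisy", "dandelion"]
actual = ["daisy", "daisy", "daisy", "dandelion", "dandelion", "dandelion"]
predicted = ["daisy", "daisy", "dandelion", "dandelion", "dandelion", "dandelion"]

matrix = confusion_matrix(actual, predicted, classes)
for class_name, row in zip(classes, matrix):
    print(class_name, row)
```

The diagonal holds the correct predictions and the off-diagonal entries are the mistakes. In this toy example, one daisy was predicted as a dandelion.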

And you can inspect which images are the hardest for your model to learn. Some of these daisies look like dandelions!

Hardest images for your model to learn - some of these daisies look like dandelions!

Lastly, we can run inference on our test set, which contains images our model has never seen.

#run inference on test images
import glob
from IPython.display import Image, display

model = learn.model
model = model.cuda()
for imageName in glob.glob('/content/test/*/*.jpg'):
    print(imageName)
    img = open_image(imageName)
    prediction = learn.predict(img)
    #print(prediction)
    print(prediction[0])
    display(Image(filename=imageName))
    print("\n")

Yielding inference:

Test inference on images the model has never seen
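If you want a single number out of this loop, one option is to compare each prediction with the name of the folder the image came from. The sketch below assumes the test set is laid out with one subfolder per class, as the `/content/test/*/*.jpg` pattern suggests, and that you collected `(imageName, str(prediction[0]))` pairs while looping. `accuracy_from_paths` and the sample results are hypothetical.

```python
from pathlib import PurePosixPath

def accuracy_from_paths(results):
    """results: list of (image_path, predicted_label) pairs.

    The true label is taken from the name of the image's parent folder.
    """
    if not results:
        raise ValueError("results is empty")
    correct = sum(
        1 for path, predicted in results
        if PurePosixPath(path).parent.name == predicted
    )
    return correct / len(results)

# Made-up predictions standing in for the output of the inference loop.
sample_results = [
    ("/content/test/daisy/img_001.jpg", "daisy"),
    ("/content/test/daisy/img_002.jpg", "dandelion"),
    ("/content/test/dandelion/img_003.jpg", "dandelion"),
    ("/content/test/dandelion/img_004.jpg", "dandelion"),
]
test_accuracy = accuracy_from_paths(sample_results)
print(test_accuracy)
```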

Saving Custom Resnet Image Classification Weights

For the last step of the notebook, we provide code to export your model weights for future use.

These are .pth PyTorch weights. They can be loaded again with the same fastai library or directly in PyTorch, and from there the model can be converted to TorchScript or exported to ONNX.

The rest of the application is up to you 🚀

Conclusion

Congratulations! You have now learned how to train a custom Resnet34 image classification model to differentiate between any type of image in the world. All it takes is the right dataset, dataset management tools, and model architecture.

And better yet, image classification is coming soon to the Roboflow Train one-click integration. Stay tuned.

We hope you enjoyed this tutorial. Happy classifying!