PaliGemma, released by Google in May 2024, is a Large Multimodal Model (LMM). You can use PaliGemma for Visual Question Answering (VQA), to detect objects in images, or even to generate segmentation masks.

While PaliGemma has zero-shot capabilities – meaning the model can identify objects without fine-tuning – such abilities are limited. Google strongly recommends fine-tuning the model for optimal performance in specific domains.

One domain where foundational models typically do not perform well is medical imaging. In this guide, we will walk through fine-tuning PaliGemma to detect fractures in X-ray images. To do this, we will use one of the datasets available on Roboflow Universe.

JAX/FLAX PaliGemma 3B is available in three different versions, differing in input image resolution (224, 448, and 896) and input text sequence length (128, 512, and 512 tokens respectively).

To limit GPU memory consumption and enable fine-tuning in Google Colab, we will use the smallest version, paligemma-3b-pt-224, in this tutorial. You will need a GPU runtime with at least 12 GB of memory; a Google Colab runtime with an NVIDIA T4 is sufficient.
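If you want to confirm which GPU your runtime was assigned (and that it meets the memory requirement), an optional quick check with nvidia-smi will show the card and its total memory:

!nvidia-smi --query-gpu=name,memory.total --format=csv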

💡
Open the notebook that accompanies this guide.

To fine-tune PaliGemma, we will:

  1. Download the object detection dataset in PaliGemma JSONL format;
  2. Install the required dependencies;
  3. Download pre-trained PaliGemma weights and tokenizer from Kaggle;
  4. Fine-tune PaliGemma using JAX;
  5. Save our model for later use.
💡
The model will be fine-tuned using JAX, a low-level deep learning framework by Google. As a result, code snippets for loading data, training the model, and evaluating results may be lengthy and will not be included in full in this blog post. The complete code can be found in the accompanying notebook.

Without further ado, let’s get started!

Step #1: Download an object detection dataset

To fine-tune PaliGemma for object detection, you need a dataset in the PaliGemma JSONL format. This format is not typically used for training traditional computer vision models like YOLO, but it is commonly used for training language models. In a JSONL file, each line is a separate JSON object, so the dataset reads like a list of individual records.

In our case, each record contains the name of the associated image, a prefix (prompt) that will be passed to the model, and a suffix (expected response) from the model. Here is a single object from our dataset:

{'image': 'n_0_2513_png_jpg.rf.1f679ff5dec5332cf06f6b9593c8437b.jpg', 'prefix': 'detect fracture', 'suffix': '<loc0390><loc0241><loc0472><loc0440> fracture'}

In the prompt, pay attention to the keyword `detect`, followed by a list of the classes we want to detect, separated by semicolons. The expected detection result is described by a bounding box in the format '<loc{Y1}><loc{X1}><loc{Y2}><loc{X2}>' followed by the class name. The values X1, Y1, X2, and Y2 describe the location of the bounding box, normalized to an image size of 1024x1024. Each value should have four digits; if a coordinate is shorter, it is padded with leading zeros.
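To make the coordinate encoding concrete, here is a minimal sketch that decodes a suffix like the one above back into pixel coordinates. The helper name and regular expression are our own, not part of PaliGemma's tooling, and the sketch handles a single detection only:

import re

def decode_suffix(suffix: str, image_wh: tuple) -> tuple:
  # Extract the four zero-padded values from the '<locYYYY>' tokens: y1, x1, y2, x2.
  y1, x1, y2, x2 = [int(v) for v in re.findall(r"<loc(\d{4})>", suffix)]
  w, h = image_wh
  # Rescale from the 1024x1024 normalized space to the actual image size.
  return (x1 / 1024 * w, y1 / 1024 * h, x2 / 1024 * w, y2 / 1024 * h)

# Example: the record above, for a hypothetical 640x640 X-ray image.
print(decode_suffix('<loc0390><loc0241><loc0472><loc0440> fracture', (640, 640)))
# -> (150.625, 243.75, 275.0, 295.0)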

Roboflow has full support for the PaliGemma JSONL format, and it can be used to export any of the 250,000+ datasets on Roboflow Universe.

First, install the required dependencies to download and parse a dataset:

pip install roboflow supervision

For this guide, we will download a fracture detection dataset using a Roboflow API key:

from google.colab import userdata
from roboflow import Roboflow

ROBOFLOW_API_KEY = userdata.get('ROBOFLOW_API_KEY')

rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("srinithi-s-tzdkb").project("fracture-detection-rhud5")
version = project.version(4)
dataset = version.download("PaliGemma")
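With the dataset downloaded, you can quickly check how many records each annotation split contains. The exact split file names may vary between exports, so this optional snippet simply globs for them:

import glob
import os

# Count the records in each PaliGemma JSONL annotation file exported by Roboflow.
for path in sorted(glob.glob(f"{dataset.location}/dataset/_annotations.*.jsonl")):
  with open(path) as f:
    count = sum(1 for _ in f)
  print(os.path.basename(path), count)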

Before we start fine-tuning, let's ensure the dataset is correctly formatted by visualizing one of the examples from our dataset.

from PIL import Image
import json

import supervision as sv

# Read the first record from the training annotations file.
with open(f"{dataset.location}/dataset/_annotations.train.jsonl") as f:
  first = json.loads(f.readline())
print(first)

image = Image.open(f"{dataset.location}/dataset/{first.get('image')}")
CLASSES = first.get('prefix').replace("detect ", "").split(" ; ")

# `from_pali_gemma` is a helper defined in the accompanying notebook; it parses
# the suffix into an `sv.Detections` object (similar in spirit to the parsing sketch above).
detections = from_pali_gemma(first.get('suffix'), image.size, CLASSES)

sv.BoundingBoxAnnotator().annotate(image, detections)

Now that we know our annotations are correctly displayed, we can set up our Python environment and start fine-tuning. Most of the code in this section comes from the official Google Colab released by the PaliGemma team.

Step #2: Model setup

To train a PaliGemma model for object detection, we are going to use the big_vision project maintained by Google Research. We can install this project using the following code:

import os
import sys

# Colab's remote TPU runtime is not supported by this notebook.
if "COLAB_TPU_ADDR" in os.environ:
  raise RuntimeError("It seems you are using Colab with remote TPUs, which is not supported.")

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
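After the install, it is worth an optional sanity check that the repository is importable and that JAX can see your GPU:

import jax
import big_vision.utils  # verifies the cloned repo is on the import path

print("JAX version:", jax.__version__)
print("JAX devices:", jax.devices())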

Once you have installed big_vision, you next need to download the PaliGemma model weights. These weights are available on Kaggle; you will need a Kaggle account to download them, and you must agree to the PaliGemma terms of service on Kaggle before you can use the model weights.

Once you have set up your Kaggle account and agreed to the terms of service, you can download the PaliGemma weights using the following code:

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

import kagglehub

MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  # Note: the Kaggle archive contains the same checkpoint in multiple formats.
  # Download only the float16 model.
  MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

Step #3: Train a PaliGemma model for object detection

With the model weights downloaded, we are now ready to train a PaliGemma model on a custom object detection dataset. The code for this step is long, so we will not reproduce it in full here; follow the accompanying notebook for all of the code you need to train your model.

The steps that we need to follow to train a model are:

  1. Import all of the required dependencies.
  2. Construct the model using the ml_collections library.
  3. Load the model weights into RAM for use in training.
  4. Move parameters to GPU/TPU memory for use in training.
  5. Define preprocessing functions for images and tokens.
  6. Define a training loop that will iterate over all of the train and validation examples, using the PaliGemma JSONL format.
  7. Run a training loop with a specified learning rate and number of examples to fine-tune the model.

All of these steps are documented in the Colab notebook that accompanies this post.

In our Colab, we set the batch size to 8, the learning rate to 0.01, and define the number of train and evaluation steps as:

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.01

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 8
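The actual training loop lives in the notebook, but at its core it is plain stochastic gradient descent. A rough sketch of a single update step is shown below; `loss_fn` and `train_data_iterator` stand in for helpers defined in the notebook and are not exact big_vision APIs, and the notebook's real loop also handles details such as parameter sharding and which layers to train:

import jax

# Illustrative SGD step: compute the loss and gradients, then move every
# parameter a small step against its gradient.
@jax.jit
def sgd_step(params, batch, learning_rate):
  loss, grads = jax.value_and_grad(loss_fn)(params, batch)
  new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
  return new_params, loss

for step in range(TRAIN_STEPS):
  batch = next(train_data_iterator)  # hypothetical iterator yielding preprocessed batches
  params, loss = sgd_step(params, batch, LEARNING_RATE)
  if step % 8 == 0:
    print(f"step {step:3d}: loss = {float(loss):.4f}")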

With a trained model, we can now test it.

Step #4: Test the fine-tuned object detection model

In our Colab notebook, we declare a function called make_predictions, which takes a data iterator, runs inference on each image, and returns the image alongside the model's predicted caption.

We can use this function to test our fine-tuned object detection model:

from IPython.display import display, HTML

# `make_predictions`, `validation_data_iterator`, and `render_example` are all
# defined in the accompanying notebook.
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))

Here is a selection of results from our model when run on the validation dataset for our project:

The examples above come from the validation set; the pink bounding boxes correspond to detections from the model, and the text labels on the right indicate the class that was identified (“fracture”).

You can save your model using the following code for later use:

import numpy as np
import big_vision.utils

flat, _ = big_vision.utils.tree_flatten_with_names(params)
with open("/content/fine-tuned-PaliGemma-3b-pt-224.f16.npz", "wb") as f:
  np.savez(f, **{k: v for k, v in flat})
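When you want to use the fine-tuned checkpoint later, the .npz file can be read back with NumPy; rebuilding the nested parameter tree that big_vision expects is handled by the same kind of loading code used in the notebook. As a quick way to confirm the file saved correctly:

import numpy as np

ckpt = np.load("/content/fine-tuned-PaliGemma-3b-pt-224.f16.npz")
print(f"Saved {len(ckpt.files)} parameter arrays, e.g. {ckpt.files[:3]}")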

Roboflow is actively working on a solution for deploying PaliGemma models on your own hardware that will consume the saved weights. We will update this guide when our deployment solution is available. For now, you can deploy the default weights using Roboflow Inference.

Conclusion

PaliGemma is a multimodal vision model developed by Google. PaliGemma can be used to identify the location of objects in an image, and to generate segmentation masks that correspond to specific objects in an image.

In this guide, we walked through how to fine-tune PaliGemma for object detection using a custom dataset, with reference to a notebook adapted from Google’s official PaliGemma fine-tuning notebook.

We downloaded a compatible dataset from Roboflow Universe, visually checked to ensure annotations were correctly stored in the PaliGemma format, then ran a training job on Google Colab. We then tested our model with the corresponding validation dataset for our project, achieving strong results.