Fine-tuning a VLM for Object Detection Grounding using TRL
Authored by: Sergio Paniego
🚨 WARNING: This notebook is resource-intensive and requires substantial computational power. If you’re running it in Colab, it will utilize an A100 GPU.
🔍 What You’ll Learn
In this recipe, we’ll demonstrate how to fine-tune a Vision-Language Model (VLM) for object detection grounding using TRL.
Traditionally, object detection involves identifying a predefined set of classes (e.g., “car”, “person”, “dog”) within an image. However, this paradigm shifted with models like Grounding DINO, GLIP, or OWL-ViT, which introduced open-ended object detection—enabling models to detect any class described in natural language.
Grounding goes a step further by adding contextual understanding. Instead of just detecting a “car”, grounded detection can locate the “car on the left”, or the “red car behind the tree”. This provides a more nuanced and powerful approach to object detection.
In this recipe, we’ll walk through how to fine-tune a VLM for this task. Specifically, we’ll use PaliGemma 2, a Vision-Language Model developed by Google that supports object detection out of the box. While not all VLMs offer detection capabilities by default, the concepts and steps in this notebook can be adapted for models without built-in object detection as well.
To train our model, we’ll use RefCOCO, an extension of the popular COCO dataset, designed specifically for referring expression comprehension—that is, combining object detection with grounding through natural language.
This recipe also builds upon my recent release of this Space, which lets you compare different VLMs on object understanding tasks such as object detection, keypoint detection, and more.
📚 Additional Resources
At the end of this notebook, you’ll find extra resources if you’re interested in exploring the topic further.
1. Install dependencies
Let’s start by installing the required dependencies:
!pip install -Uq transformers datasets trl supervision albumentations
We’ll log in to our Hugging Face account to access gated models and save our trained checkpoints.
You’ll need an access token 🗝️.
from huggingface_hub import notebook_login
notebook_login()
2. 📁 Load Dataset
For this example, we’ll use RefCOCO, a dataset that includes grounded object detection annotations—enabling more robust and context-aware detection.
To keep things simple and efficient, we’ll work with a subset of the dataset.
from datasets import load_dataset
refcoco_dataset = load_dataset("jxu124/refcoco", split="train[:5%]")
After loading it, let’s see what’s inside:
refcoco_dataset
We can see that the dataset contains useful information such as the bbox and captions columns. In this case, bboxes follow an xyxy format.
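As a quick optional check, we can also print the annotation fields of a single sample to see what a bbox and its captions look like:
sample = refcoco_dataset[13]
print(sample["bbox"])
print(sample["captions"])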
However, the image itself isn’t directly accessible from these fields. For more details about the image source, we can inspect the raw_image_info column.
refcoco_dataset[13]["raw_image_info"]
2.1 🖼️ Add Images to the Dataset
While we could link each example to the corresponding image in the COCO dataset, we’ll simplify the process by downloading the images directly from Flickr.
However, this approach may result in some missing images, so we’ll need to handle those cases accordingly.
import json
import requests
from PIL import Image
from io import BytesIO
def add_image(example):
try:
raw_info = json.loads(example["raw_image_info"])
url = raw_info.get("flickr_url", None)
if url:
response = requests.get(url, timeout=10)
image = Image.open(BytesIO(response.content)).convert("RGB")
example["image"] = image
else:
example["image"] = None
except Exception as e:
print(f"Error loading image: {e}")
example["image"] = None
return example
refcoco_dataset_with_images = refcoco_dataset.map(add_image, desc="Adding image from flickr", num_proc=16)
Awesome! Our images are now downloaded and ready to go.
refcoco_dataset_with_images
Next, let’s filter the dataset to include only samples that have an associated image:
filtered_dataset = refcoco_dataset_with_images.filter(
lambda example: example["image"] is not None, desc="Removing failed image downloads"
)
2.2 Remove Unneeded Columns
filtered_dataset
The dataset contains many columns that we won’t need for this task.
Let’s simplify it by keeping only the 'bbox', 'captions', and 'image' columns.
filtered_dataset = filtered_dataset.remove_columns(
[
"sent_ids",
"file_name",
"ann_id",
"ref_id",
"image_id",
"split",
"sentences",
"category_id",
"raw_anns",
"raw_image_info",
"raw_sentences",
"image_path",
"global_image_id",
"anns_id",
]
)
It looks much better now!
filtered_dataset
2.3 Separate Captions into Unique Samples
One final step: each sample currently has multiple captions. To simplify the dataset, we’ll split these so that each caption becomes a unique sample.
def separate_captions_into_unique_samples(batch):
new_images = []
new_bboxes = []
new_captions = []
for image, bbox, captions in zip(batch["image"], batch["bbox"], batch["captions"]):
for caption in captions:
new_images.append(image)
new_bboxes.append(bbox)
new_captions.append(caption)
return {
"image": new_images,
"bbox": new_bboxes,
"caption": new_captions,
}
filtered_dataset = filtered_dataset.map(
separate_captions_into_unique_samples,
batched=True,
batch_size=100,
num_proc=4,
remove_columns=filtered_dataset.column_names,
)
Now that everything is prepared, let’s take a look at an example!
filtered_dataset[20]["caption"]
filtered_dataset[20]["bbox"]
>>> filtered_dataset[20]["image"]
2.4 Display a Sample with Bounding Boxes
Our dataset preparation is complete. Now, let’s visualize the bounding boxes on an image from a sample.
To do this, we’ll create an auxiliary function that we can reuse throughout the recipe.
We’ll use the supervision library to assist with displaying the bounding boxes.
labels = [(filtered_dataset[20]["caption"], filtered_dataset[20]["bbox"])]
>>> import supervision as sv
>>> import numpy as np
>>> def get_annotated_image(image, parsed_labels):
... if not parsed_labels:
... return image
... xyxys = []
... labels = []
... for label, bbox in parsed_labels:
... xyxys.append(bbox)
... labels.append(label)
... detections = sv.Detections(xyxy=np.array(xyxys))
... bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
... label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
... annotated_image = bounding_box_annotator.annotate(scene=image, detections=detections)
... annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
... return annotated_image
>>> annotated_image = get_annotated_image(filtered_dataset[20]["image"], labels)
>>> annotated_image
Great! We can now see the grounding caption associated with each bounding box.
2.5 Divide the Dataset
Our dataset is ready, but before we proceed, let’s split it into training and validation sets for proper model evaluation.
split_dataset = filtered_dataset.train_test_split(test_size=0.2, seed=42, shuffle=False)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]
train_dataset, val_dataset
3. Check the Pretrained Model with the Dataset
As mentioned earlier, we’ll be using PaliGemma 2 as our model since it already includes object detection capabilities, which simplifies our workflow.
If we were using a Vision-Language Model (VLM) without built-in object detection capabilities, we would likely need to train it first to acquire them.
For more on this, check out our project on “Fine-tuning Gemma 3 for Object Detection” that covers this training process in detail.
Now, let’s load the model and processor. We’ll use the pretrained model google/paligemma2-3b-pt-448, which is not fine-tuned for conversational tasks.
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
import torch
model_id = "google/paligemma2-3b-pt-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id, use_fast=True)
3.1 Inference on One Sample
Let’s evaluate the current performance of the model on a single image and caption.
image = train_dataset[20]["image"]
caption = train_dataset[20]["caption"]
Since our model is not an instruct model, the input should be formatted as follows:
<image>detect [CAPTION]
Here, <image> represents the image token, followed by the keyword detect to specify the object detection task, and then the caption describing what to detect.
This format will produce a specific output, as we will see next.
>>> prompt = f"<image>detect {caption}"
>>> model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
>>> input_len = model_inputs["input_ids"].shape[-1]
>>> with torch.inference_mode():
... generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
... generation = generation[0][input_len:]
... output = processor.decode(generation, skip_special_tokens=True)
... print(output)
middle vase ; middle vase ; middle vase
We can see that the model generates location tokens in a special format like <locXXXX>..., followed by the detected category. Each detection is separated by a ;.
These location tokens follow the PaliGemma format, which is specific to the model and relative to the input size, 448x448 in this case, as indicated by the model name.
To display the detections correctly, we need to convert these tokens back to a usable format. Let’s create an auxiliary function to handle this conversion:
import re
# https://github.com/ariG23498/gemma3-object-detection/blob/main/utils.py#L17 thanks to Aritra Roy Gosthipaty
def parse_paligemma_labels(label, width, height):
predictions = label.strip().split(";")
results = []
for pred in predictions:
pred = pred.strip()
if not pred:
continue
loc_pattern = r"<loc(\d{4})>"
locations = [int(loc) for loc in re.findall(loc_pattern, pred)]
if len(locations) != 4:
continue
category = pred.split(">")[-1].strip()
y1_norm, x1_norm, y2_norm, x2_norm = locations
x1 = (x1_norm / 1024) * width
y1 = (y1_norm / 1024) * height
x2 = (x2_norm / 1024) * width
y2 = (y2_norm / 1024) * height
results.append((category, [x1, y1, x2, y2]))
return results
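To see how the conversion works, here’s a quick check on a made-up label string (the loc values below are illustrative, not real model output). On a 448x448 image, it maps back to roughly [49.9, 39.8, 150.1, 140.0] in xyxy pixels:
example_output = "<loc0091><loc0114><loc0320><loc0343> middle vase"
parse_paligemma_labels(example_output, 448, 448)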
Now, we can use this function to parse the PaliGemma labels into plain xyxy pixel coordinates.
width, height = image.size
parsed_labels = parse_paligemma_labels(output, width, height)
parsed_labels
Next, we can reuse the get_annotated_image function defined earlier to annotate the image.
Let’s display it along with the parsed bounding boxes!
annotated_image = get_annotated_image(image, parsed_labels)
>>> annotated_image
We can see that the model performs well on object detection, but it struggles a bit with grounding.
For example, it labels all three vases as the “middle vase” instead of just one.
Let’s work on improving that! 🙂
4. Fine-Tuning the Model Using the Dataset with LoRA and TRL
To fine-tune the Vision-Language Model (VLM), we will leverage LoRA and TRL.
Let’s start by configuring LoRA:
>>> from peft import LoraConfig, get_peft_model
>>> target_modules = ["q_proj", "v_proj", "fc1", "fc2", "linear", "gate_proj", "up_proj", "down_proj"]
>>> # Configure LoRA
>>> peft_config = LoraConfig(
... lora_alpha=16,
... lora_dropout=0.05,
... r=8,
... bias="none",
... target_modules=target_modules,
... task_type="CAUSAL_LM",
... )
>>> # Apply PEFT model adaptation
>>> peft_model = get_peft_model(model, peft_config)
>>> # Print trainable parameters
>>> peft_model.print_trainable_parameters()
trainable params: 12,165,888 || all params: 3,045,293,040 || trainable%: 0.3995
Next, let’s configure the SFT training pipeline from TRL.
This pipeline simplifies the training process by abstracting much of the underlying complexity and managing it for us.
from trl import SFTConfig
training_args = SFTConfig(
output_dir="paligemma2-3b-pt-448-od-grounding",
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=4,
gradient_checkpointing=False,
learning_rate=1e-05,
num_train_epochs=2,
logging_steps=10,
eval_steps=100,
eval_strategy="steps",
save_steps=10,
bf16=True,
report_to=["tensorboard"],
dataset_kwargs={"skip_prepare_dataset": True},
remove_unused_columns=False,
push_to_hub=True,
dataloader_pin_memory=False,
label_names=["labels"],
)
We’re almost ready!
Next, we’ll define a few auxiliary functions to handle object detection within the collator.
These functions are straightforward and self-explanatory.
def coco_to_xyxy(coco_bbox):
x, y, width, height = coco_bbox
x1, y1 = x, y
x2, y2 = x + width, y + height
return [x1, y1, x2, y2]
def convert_to_detection_string(bboxs, image_width, image_height, category):
def format_location(value, max_value):
return f"<loc{int(round(value * 1024 / max_value)):04}>"
detection_strings = []
for bbox in bboxs:
x1, y1, x2, y2 = coco_to_xyxy(bbox)
locs = [
format_location(y1, image_height),
format_location(x1, image_width),
format_location(y2, image_height),
format_location(x2, image_width),
]
detection_string = "".join(locs) + f" {category}"
detection_strings.append(detection_string)
return " ; ".join(detection_strings)
def format_objects(example):
height = example["height"]
width = example["width"]
bboxs = example["bbox"]
category = example["caption"][0]
formatted_objects = convert_to_detection_string(bboxs, width, height, category)
return {"label_for_paligemma": formatted_objects}
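As a quick sanity check, this is what a target string looks like for a made-up COCO-format box [x, y, width, height] = [50, 40, 100, 100] on a 448x448 image (the box and label are purely illustrative):
print(convert_to_detection_string([[50, 40, 100, 100]], 448, 448, "middle vase"))
# -> <loc0091><loc0114><loc0320><loc0343> middle vase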
Since we’re fine-tuning a VLM, we can also incorporate data augmentation.
In our case, we’ll handle image resizing—which is mandatory to ensure consistent input size—because the model expects images of 448x448.
For reference, we’ve included a couple of possible augmentations commented out.
import albumentations as A
resize_size = 448
augmentations = A.Compose(
[
A.Resize(height=resize_size, width=resize_size),
# A.HorizontalFlip(p=0.5),
# A.ColorJitter(p=0.2),
],
bbox_params=A.BboxParams(format="coco", label_fields=["category_ids"], filter_invalid_bboxes=True),
)
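Optionally, we can apply the transform to a single training sample, mirroring what the collator will do, to verify that the image is resized to 448x448 and the bounding box is rescaled accordingly:
sample = train_dataset[20]
transformed = augmentations(
    image=np.array(sample["image"]), bboxes=[sample["bbox"]], category_ids=[sample["caption"]]
)
print(transformed["image"].shape)  # expected: (448, 448, 3)
print(transformed["bboxes"])  # box rescaled to the 448x448 canvas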
Now, let’s create the collate function that prepares batches for input to the VLM.
In this step, we need to carefully handle the data augmentation process to ensure consistency and correctness.
from functools import partial
# Create a data collator to encode text and image pairs
def collate_fn(examples, transform=None):
images = []
prompts = []
suffixes = []
for sample in examples:
if transform:
transformed = transform(
image=np.array(sample["image"]), bboxes=[sample["bbox"]], category_ids=[sample["caption"]]
)
sample["image"] = transformed["image"]
sample["bbox"] = transformed["bboxes"]
sample["caption"] = transformed["category_ids"]
sample["height"] = sample["image"].shape[0]
sample["width"] = sample["image"].shape[1]
sample["label_for_paligemma"] = format_objects(sample)["label_for_paligemma"]
images.append([sample["image"]])
prompts.append(f"<image>Detect {sample['caption']}.")
suffixes.append(sample["label_for_paligemma"])
batch = processor(images=images, text=prompts, suffix=suffixes, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone() # Clone input IDs for labels
image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index("<image>")
]
# Mask tokens for not being used in the loss computation
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
batch["labels"] = labels
batch["pixel_values"] = batch["pixel_values"].to(model.device)
return batch
train_collate_fn = partial(collate_fn, transform=augmentations)
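Before launching the training, an optional sanity check is to run the collator on a couple of samples and inspect the resulting tensor shapes:
debug_batch = train_collate_fn([train_dataset[0], train_dataset[1]])
print(debug_batch["input_ids"].shape)  # (2, sequence_length)
print(debug_batch["pixel_values"].shape)  # expected (2, 3, 448, 448) for this checkpoint
print(debug_batch["labels"].shape)  # same shape as input_ids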
Finally, we can instantiate the SFTTrainer and start training our model!
from trl import SFTTrainer
trainer = SFTTrainer(
model=peft_model,
args=training_args,
data_collator=train_collate_fn,
train_dataset=train_dataset,
eval_dataset=val_dataset,
)
trainer.train()
Let’s save the trained model and results to the Hugging Face Hub.
processor.save_pretrained(training_args.output_dir)
trainer.save_model(training_args.output_dir)
trainer.push_to_hub()
5. Test the Fine-Tuned Model
We have fine-tuned our model for grounded object detection. As the final step, let’s test its capabilities on a sample from the test set.
Model: sergiopaniego/paligemma2-3b-pt-448-od-grounding
Let’s instantiate our model using the fine-tuned checkpoint:
trained_model_id = "sergiopaniego/paligemma2-3b-pt-448-od-grounding"
model_id = "google/paligemma2-3b-pt-448"
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
from peft import PeftModel
import torch
base_model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
trained_model = PeftModel.from_pretrained(base_model, trained_model_id).eval()
trained_processor = PaliGemmaProcessor.from_pretrained(model_id, use_fast=True)
5.1 Test on a Training Sample
Let’s start by testing on one of the training images.
This gives us an initial sense of how the training went, but keep in mind it can be a bit misleading since the model has already seen this sample during training.
For this test, we’ll use the example we presented earlier to check if the model can now perform inference correctly.
image = train_dataset[20]["image"]
caption = train_dataset[20]["caption"]
>>> prompt = f"<image>detect {caption}"
>>> model_inputs = (
... trained_processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(trained_model.device)
... )
>>> input_len = model_inputs["input_ids"].shape[-1]
>>> with torch.inference_mode():
... generation = trained_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
... generation = generation[0][input_len:]
... output = trained_processor.decode(generation, skip_special_tokens=True)
... print(output)
middle vase
width, height = image.size
parsed_labels = parse_paligemma_labels(output, width, height)
parsed_labels
annotated_image = get_annotated_image(image, parsed_labels)
Let’s see if the fine-tuning was successful… 🥁
>>> annotated_image
Nice! The model is now able to correctly recognize the “middle vase”.
5.2 Test Against a Validation Sample
Finally, let’s evaluate the model’s capabilities on a validation sample to properly assess whether it has learned both grounding and object detection.
image = val_dataset[13]["image"]
caption = val_dataset[13]["caption"]
caption
>>> prompt = f"<image>detect {caption}"
>>> model_inputs = (
... trained_processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(trained_model.device)
... )
>>> input_len = model_inputs["input_ids"].shape[-1]
>>> with torch.inference_mode():
... generation = trained_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
... generation = generation[0][input_len:]
... output = trained_processor.decode(generation, skip_special_tokens=True)
... print(output)
darker bear
width, height = image.size
parsed_labels = parse_paligemma_labels(output, width, height)
parsed_labels
annotated_image = get_annotated_image(image, parsed_labels)
Let’s check… 🥁
>>> annotated_image
It works! Our model is able to correctly identify the “darker bear” in the image and avoids generating multiple detections for each bear.
Keep in mind that our training was light—using only a subset of the dataset—and the training configuration can be further optimized. We leave those improvements for you to explore!
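If you want a starting point for a more quantitative evaluation, here is a minimal sketch that compares the predicted box with the ground-truth box of the validation sample above using IoU. It assumes parsed_labels contains a single prediction and that the dataset bbox uses the same xyxy pixel convention as the boxes displayed earlier; adapt it if your boxes follow a different format.
def box_iou(box_a, box_b):
    # Intersection over union between two xyxy boxes
    xa1, ya1, xa2, ya2 = box_a
    xb1, yb1, xb2, yb2 = box_b
    inter_w = max(0.0, min(xa2, xb2) - max(xa1, xb1))
    inter_h = max(0.0, min(ya2, yb2) - max(ya1, yb1))
    inter = inter_w * inter_h
    union = (xa2 - xa1) * (ya2 - ya1) + (xb2 - xb1) * (yb2 - yb1) - inter
    return inter / union if union > 0 else 0.0

pred_box = parsed_labels[0][1]
gt_box = val_dataset[13]["bbox"]
print(f"IoU with the ground-truth box: {box_iou(pred_box, gt_box):.2f}")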
6. Continuing the Learning Journey 🧑‍🎓
To further enhance your understanding and skills, check out these valuable resources:
- Fine-tuning Grounding DINO — LearnOpenCV
- RefCOCO Dataset — Papers with Code
- Fine-tune PaliGemma — GitHub
- Fine-tuning Gemma 3 for Object Detection — GitHub
- VLM Object Understanding — Hugging Face Space
- How Well Does GPT-4o Understand Vision? Evaluating Multimodal Foundation Models on Standard Computer Vision Tasks paper
- Vision Language Models (Better, Faster, Stronger) blog
- Check out other multimodal recipes in the HF Open-Source AI Cookbook
Feel free to explore these to deepen your knowledge and keep pushing the boundaries!