
Fine tuning a VLM for Object Detection Grounding using TRL

Authored by: Sergio Paniego

🚨 WARNING: This notebook is resource-intensive and requires substantial computational power. If you’re running it in Colab, it will utilize an A100 GPU.

🔍 What You’ll Learn

In this recipe, we’ll demonstrate how to fine-tune a Vision-Language Model (VLM) for object detection grounding using TRL.

Traditionally, object detection involves identifying a predefined set of classes (e.g., “car”, “person”, “dog”) within an image. However, this paradigm shifted with models like Grounding DINO, GLIP, or OWL-ViT, which introduced open-ended object detection—enabling models to detect any class described in natural language.

Grounding goes a step further by adding contextual understanding. Instead of just detecting a “car”, grounded detection can locate the “car on the left”, or the “red car behind the tree”. This provides a more nuanced and powerful approach to object detection.

In this recipe, we’ll walk through how to fine-tune a VLM for this task. Specifically, we’ll use PaliGemma 2, a Vision-Language Model developed by Google that supports object detection out of the box. While not all VLMs offer detection capabilities by default, the concepts and steps in this notebook can be adapted for models without built-in object detection as well.

To train our model, we’ll use RefCOCO, an extension of the popular COCO dataset, designed specifically for referring expression comprehension—that is, combining object detection with grounding through natural language.

This recipe also builds upon my recent release of this Space, which lets you compare different VLMs on object understanding tasks such as object detection, keypoint detection, and more.

📚 Additional Resources
At the end of this notebook, you’ll find extra resources if you’re interested in exploring the topic further.


1. Install dependencies

Let’s start by installing the required dependencies:

!pip install -Uq transformers datasets trl peft supervision albumentations

We’ll log in to our Hugging Face account to access gated models and save our trained checkpoints.
You’ll need an access token 🗝️.

from huggingface_hub import notebook_login

notebook_login()

2. 📁 Load Dataset

For this example, we’ll use RefCOCO, a dataset that includes grounded object detection annotations—enabling more robust and context-aware detection.

To keep things simple and efficient, we’ll work with a subset of the dataset: the first 5% of the training split.


from datasets import load_dataset

refcoco_dataset = load_dataset("jxu124/refcoco", split="train[:5%]")

After loading it, let’s see what’s inside:

refcoco_dataset

We can see that the dataset contains useful information such as the bbox and captions columns. The bboxes are stored in absolute pixel xyxy format, that is, [x_min, y_min, x_max, y_max].

However, the image itself isn’t directly accessible from these fields. For more details about the image source, we can inspect the raw_image_info column.

refcoco_dataset[13]["raw_image_info"]

2.1 🖼️ Add Images to the Dataset

While we could link each example to the corresponding image in the COCO dataset, we’ll simplify the process by downloading the images directly from Flickr.

However, this approach may result in some missing images, so we’ll need to handle those cases accordingly.
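One possible way to reduce the number of missing images is to try more than one source. The raw_image_info string appears to mirror a COCO image record, and COCO image records usually include a coco_url next to flickr_url; that field is an assumption here, so check it in your copy of the data. The helper pick_image_urls below is hypothetical and not part of the recipe: it only returns candidate URLs in priority order, which a download function could then try one by one.

```python
import json


def pick_image_urls(raw_image_info: str) -> list[str]:
    """Return candidate download URLs, Flickr first, then COCO (hypothetical helper).

    Assumes raw_image_info is a JSON string shaped like a COCO image record,
    which may contain "flickr_url" and/or "coco_url" keys.
    """
    info = json.loads(raw_image_info)
    urls = []
    for key in ("flickr_url", "coco_url"):
        url = info.get(key)
        if url:  # skip missing or empty entries
            urls.append(url)
    return urls


# Example with a made-up record (placeholder URLs, not real images)
record = json.dumps(
    {"flickr_url": "http://farm1.example.com/a.jpg", "coco_url": "http://images.example.org/b.jpg"}
)
print(pick_image_urls(record))
```

An image is then only marked as missing when every candidate URL fails.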

import json
import requests
from PIL import Image
from io import BytesIO


def add_image(example):
    try:
        raw_info = json.loads(example["raw_image_info"])
        url = raw_info.get("flickr_url", None)
        if url:
            response = requests.get(url, timeout=10)
            image = Image.open(BytesIO(response.content)).convert("RGB")
            example["image"] = image
        else:
            example["image"] = None
    except Exception as e:
        print(f"Error loading image: {e}")
        example["image"] = None
    return example


refcoco_dataset_with_images = refcoco_dataset.map(add_image, desc="Adding image from flickr", num_proc=16)

Awesome! Our images are now downloaded and ready to go.

refcoco_dataset_with_images

Next, let’s filter the dataset to include only samples that have an associated image:

filtered_dataset = refcoco_dataset_with_images.filter(
    lambda example: example["image"] is not None, desc="Removing failed image downloads"
)

2.2 Remove Unneeded Columns

filtered_dataset

The dataset contains many columns that we won’t need for this task.
Let’s simplify it by keeping only the 'bbox', 'captions', and 'image' columns.

filtered_dataset = filtered_dataset.remove_columns(
    [
        "sent_ids",
        "file_name",
        "ann_id",
        "ref_id",
        "image_id",
        "split",
        "sentences",
        "category_id",
        "raw_anns",
        "raw_image_info",
        "raw_sentences",
        "image_path",
        "global_image_id",
        "anns_id",
    ]
)

It looks much better now!

filtered_dataset

2.3 Separate Captions into Unique Samples

One final step: each sample currently has multiple captions. To simplify the dataset, we’ll split these so that each caption becomes a unique sample.

def separate_captions_into_unique_samples(batch):
    new_images = []
    new_bboxes = []
    new_captions = []

    for image, bbox, captions in zip(batch["image"], batch["bbox"], batch["captions"]):
        for caption in captions:
            new_images.append(image)
            new_bboxes.append(bbox)
            new_captions.append(caption)

    return {
        "image": new_images,
        "bbox": new_bboxes,
        "caption": new_captions,
    }


filtered_dataset = filtered_dataset.map(
    separate_captions_into_unique_samples,
    batched=True,
    batch_size=100,
    num_proc=4,
    remove_columns=filtered_dataset.column_names,
)

Now that everything is prepared, let’s take a look at an example!

filtered_dataset[20]["caption"]
filtered_dataset[20]["bbox"]
filtered_dataset[20]["image"]

2.4 Display a Sample with Bounding Boxes

Our dataset preparation is complete. Now, let’s visualize the bounding boxes on an image from a sample.
To do this, we’ll create an auxiliary function that we can reuse throughout the recipe.

We’ll use the supervision library to assist with displaying the bounding boxes.

import supervision as sv
import numpy as np


def get_annotated_image(image, parsed_labels):
    # parsed_labels: list of (label, [x1, y1, x2, y2]) tuples in pixel coordinates
    if not parsed_labels:
        return image

    xyxys = []
    labels = []

    for label, bbox in parsed_labels:
        xyxys.append(bbox)
        labels.append(label)

    detections = sv.Detections(xyxy=np.array(xyxys))

    bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

    annotated_image = bounding_box_annotator.annotate(scene=image, detections=detections)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

    return annotated_image


labels = [(filtered_dataset[20]["caption"], filtered_dataset[20]["bbox"])]
annotated_image = get_annotated_image(filtered_dataset[20]["image"], labels)
annotated_image

Great! We can now see the grounding caption associated with each bounding box.

2.5 Divide the Dataset

Our dataset is ready, but before we proceed, let’s split it into training and validation sets for proper model evaluation.

split_dataset = filtered_dataset.train_test_split(test_size=0.2, seed=42, shuffle=False)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]
train_dataset, val_dataset
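Because we split every caption into its own row, several consecutive rows share the same image. With shuffle=False the split is contiguous (as we understand it, the first rows go to train and the last ones to test), so rows for one image mostly stay on one side and only images at the boundary can leak into both sets. The toy sketch below uses made-up image ids (we removed the real image_id column earlier) and compares a contiguous split with a strided split, a deterministic stand-in for shuffling; contiguous_split, strided_split and shared_images are hypothetical helpers.

```python
def contiguous_split(rows, test_size):
    """First rows go to train, last rows to test (like shuffle=False)."""
    n_test = round(len(rows) * test_size)
    return rows[: len(rows) - n_test], rows[len(rows) - n_test :]


def strided_split(rows, every):
    """Deterministic stand-in for shuffling: every `every`-th row goes to test."""
    test = rows[every - 1 :: every]
    train = [r for i, r in enumerate(rows) if (i + 1) % every != 0]
    return train, test


def shared_images(train, test):
    """Number of images that appear in both splits."""
    return len(set(train) & set(test))


# Hypothetical image ids, one entry per caption row (consecutive rows share an image)
image_ids = [0, 0, 1, 1, 1, 2, 2, 3, 3, 3]

print(shared_images(*contiguous_split(image_ids, test_size=0.2)))  # 1
print(shared_images(*strided_split(image_ids, every=5)))  # 2
```

Both toy splits put two rows in the test set, but the strided one leaks twice as many images. On the real data, the contiguous split is what keeps the validation check close to an unseen-image evaluation.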

3. Check the Pretrained Model with the Dataset

As mentioned earlier, we’ll be using PaliGemma 2 as our model since it already includes object detection capabilities, which simplifies our workflow.

If we were using a Vision-Language Model (VLM) without built-in object detection capabilities, we would likely need to train it first to acquire them.

For more on this, check out our project on “Fine-tuning Gemma 3 for Object Detection” that covers this training process in detail.

Now, let’s load the model and processor. We’ll use the pretrained model google/paligemma2-3b-pt-448, which is not fine-tuned for conversational tasks.


from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
import torch

model_id = "google/paligemma2-3b-pt-448"

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id, use_fast=True)

3.1 Inference on One Sample

Let’s evaluate the current performance of the model on a single image and caption.

image = train_dataset[20]["image"]
caption = train_dataset[20]["caption"]

Since our model is not an instruct model, the input should be formatted as follows:

<image>detect [CAPTION]

Here, <image> represents the image token, followed by the keyword detect to specify the object detection task, and then the caption describing what to detect.

This format will produce a specific output, as we will see next.

prompt = f"<image>detect {caption}"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]  # keep only the newly generated tokens
    output = processor.decode(generation, skip_special_tokens=True)
    print(output)
 middle vase ;  middle vase ;  middle vase

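We can see that the model generates location tokens in a special format like <locXXXX>, followed by the detected category. Each detection consists of four location tokens and a label, and detections are separated by ;.

These location tokens follow the PaliGemma format: each one is an integer bin from 0000 to 1023 encoding a coordinate normalized by the image’s width or height, in the order y_min, x_min, y_max, x_max. Because the coordinates are normalized, they do not depend on the 448x448 input resolution, and we must rescale them to the original image size before drawing.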
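As a worked example of that rescaling, the sketch below decodes four hypothetical bin values (not real model output) for a 640x480 image. It divides by 1024 just like the parse_paligemma_labels helper defined next; note that the bins arrive y-first while the result is returned as xyxy. bins_to_xyxy is a hypothetical name used only for illustration.

```python
def bins_to_xyxy(bins, width, height):
    """Convert PaliGemma-style location bins (y1, x1, y2, x2) to pixel xyxy."""
    y1, x1, y2, x2 = bins
    return [x1 / 1024 * width, y1 / 1024 * height, x2 / 1024 * width, y2 / 1024 * height]


# Hypothetical tokens: <loc0256><loc0512><loc0768><loc1023>
print(bins_to_xyxy([256, 512, 768, 1023], width=640, height=480))
# [320.0, 120.0, 639.375, 360.0]
```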

To display the detections correctly, we need to convert these tokens back to a usable format. Let’s create an auxiliary function to handle this conversion:

import re


# https://github.com/ariG23498/gemma3-object-detection/blob/main/utils.py#L17 thanks to Aritra Roy Gosthipaty
def parse_paligemma_labels(label, width, height):
    predictions = label.strip().split(";")
    results = []

    for pred in predictions:
        pred = pred.strip()
        if not pred:
            continue

        loc_pattern = r"<loc(\d{4})>"
        locations = [int(loc) for loc in re.findall(loc_pattern, pred)]

        if len(locations) != 4:
            continue

        category = pred.split(">")[-1].strip()

        y1_norm, x1_norm, y2_norm, x2_norm = locations
        x1 = (x1_norm / 1024) * width
        y1 = (y1_norm / 1024) * height
        x2 = (x2_norm / 1024) * width
        y2 = (y2_norm / 1024) * height

        results.append((category, [x1, y1, x2, y2]))

    return results

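Now, we can use this function to convert the PaliGemma labels into (category, [x1, y1, x2, y2]) tuples in absolute pixel coordinates, the same xyxy format used by the dataset’s bbox column.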

width, height = image.size
parsed_labels = parse_paligemma_labels(output, width, height)
parsed_labels

Next, we can pass these parsed labels to the get_annotated_image helper we defined earlier.
Let’s display the image along with the parsed bounding boxes!

annotated_image = get_annotated_image(image, parsed_labels)
annotated_image

We can see that the model performs well on object detection, but it struggles a bit with grounding.
For example, it labels all three vases as the “middle vase” instead of just one.

Let’s work on improving that! 🙂

4. Fine-Tuning the Model Using the Dataset with LoRA and TRL

To fine-tune the Vision-Language Model (VLM), we will leverage LoRA and TRL.

Let’s start by configuring LoRA:

from peft import LoraConfig, get_peft_model

target_modules = ["q_proj", "v_proj", "fc1", "fc2", "linear", "gate_proj", "up_proj", "down_proj"]

# Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=target_modules,
    task_type="CAUSAL_LM",
)

# Apply PEFT model adaptation
peft_model = get_peft_model(model, peft_config)

# Print trainable parameters
peft_model.print_trainable_parameters()
trainable params: 12,165,888 || all params: 3,045,293,040 || trainable%: 0.3995
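To see where a trainable-parameter count like this comes from: LoRA freezes each targeted weight matrix W (d_out x d_in) and learns a low-rank update B·A, where A is r x d_in and B is d_out x r, and PEFT scales that update by lora_alpha / r by default. With bias="none", each adapted linear layer therefore adds r * (d_in + d_out) parameters. The layer sizes below are hypothetical round numbers, not PaliGemma 2’s actual dimensions.

```python
def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r


r, lora_alpha = 8, 16  # values from the LoraConfig above
scaling = lora_alpha / r  # the low-rank update B @ A is multiplied by this factor

# Hypothetical 2048 x 2048 projection layer
print(lora_extra_params(2048, 2048, r))  # 32768
print(scaling)  # 2.0
```

Summing this quantity over every matched module is what produces the total reported by print_trainable_parameters(); raising r grows it linearly.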

Next, let’s configure the SFT training pipeline from TRL.
This pipeline simplifies the training process by abstracting much of the underlying complexity and managing it for us.

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="paligemma2-3b-pt-448-od-grounding",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,
    learning_rate=1e-05,
    num_train_epochs=2,
    logging_steps=10,
    eval_steps=100,
    eval_strategy="steps",
    save_steps=10,
    bf16=True,
    report_to=["tensorboard"],
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    push_to_hub=True,
    dataloader_pin_memory=False,
    label_names=["labels"],
)
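With these settings, each optimizer step sees per_device_train_batch_size x gradient_accumulation_steps x number_of_devices examples. The sketch below estimates the effective batch size and the approximate number of optimizer steps; the 1,000-row dataset and the single-GPU setup are assumptions for illustration, and the Trainer’s exact step count can differ slightly because of how partial batches are rounded.

```python
import math


def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Examples contributing to one optimizer step."""
    return per_device * grad_accum * num_devices


def approx_optimizer_steps(num_rows: int, batch: int, epochs: int) -> int:
    """Approximation: one optimizer step per `batch` rows, rounded up per epoch."""
    return math.ceil(num_rows / batch) * epochs


batch = effective_batch_size(per_device=2, grad_accum=4)  # values from SFTConfig, one GPU assumed
print(batch)  # 8
print(approx_optimizer_steps(num_rows=1000, batch=batch, epochs=2))  # 250
```

With eval_steps=100, that hypothetical run would evaluate twice; plugging in len(train_dataset) gives the estimate for the real data.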

We’re almost ready!
Next, we’ll define a few auxiliary functions to handle object detection within the collator.
These functions are straightforward and self-explanatory.

def convert_to_detection_string(bboxs, image_width, image_height, category):
    # PaliGemma encodes each coordinate as one of 1024 bins (<loc0000> ... <loc1023>),
    # normalized by the image size and ordered as y_min, x_min, y_max, x_max.
    def format_location(value, max_value):
        # Clamp to 1023 so boxes touching the right or bottom edge stay in the vocabulary
        bin_index = min(int(round(value * 1024 / max_value)), 1023)
        return f"<loc{bin_index:04}>"

    detection_strings = []
    for bbox in bboxs:
        # The RefCOCO boxes in this dataset are already xyxy: [x_min, y_min, x_max, y_max]
        x1, y1, x2, y2 = bbox
        locs = [
            format_location(y1, image_height),
            format_location(x1, image_width),
            format_location(y2, image_height),
            format_location(x2, image_width),
        ]
        detection_string = "".join(locs) + f" {category}"
        detection_strings.append(detection_string)
    return " ; ".join(detection_strings)


def format_objects(example):
    height = example["height"]
    width = example["width"]
    bboxs = example["bbox"]
    category = example["caption"][0]
    formatted_objects = convert_to_detection_string(bboxs, width, height, category)
    return {"label_for_paligemma": formatted_objects}

Since we’re fine-tuning a VLM, we can also incorporate data augmentation.

In our case, we only resize the images to 448x448, the resolution this checkpoint expects, which keeps the input size consistent; albumentations rescales the bounding boxes to match. Because the dataset’s boxes are in absolute xyxy format, we declare them with albumentations’ pascal_voc format (the coco format would instead expect [x, y, width, height]).

For reference, we’ve included a couple of possible augmentations commented out. Be careful with geometric transforms for grounding: a horizontal flip moves objects without updating captions such as “the car on the left”, so the caption would no longer match the box.

import albumentations as A

resize_size = 448

augmentations = A.Compose(
    [
        A.Resize(height=resize_size, width=resize_size),
        # A.HorizontalFlip(p=0.5),  # unsafe for captions with left/right references
        # A.ColorJitter(p=0.2),
    ],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_ids"], filter_invalid_bboxes=True),
)
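Because PaliGemma’s location tokens are normalized by the image size, resizing an image and its boxes together should leave the encoded tokens unchanged, apart from rounding. The self-contained sketch below checks this for one made-up box. encode_locs is a hypothetical re-implementation of the token formatting in convert_to_detection_string for xyxy boxes (including a clamp to 1023), and scale_box is a hypothetical helper that mimics how a plain resize rescales box coordinates.

```python
def encode_locs(box, width, height):
    """Encode an xyxy box as PaliGemma-style tokens in y1, x1, y2, x2 order."""
    x1, y1, x2, y2 = box

    def loc(value, size):
        # 1024 bins, clamped so the last valid token is <loc1023>
        return f"<loc{min(int(round(value * 1024 / size)), 1023):04}>"

    return loc(y1, height) + loc(x1, width) + loc(y2, height) + loc(x2, width)


def scale_box(box, src_w, src_h, dst_w, dst_h):
    """Rescale an xyxy box the way a plain resize rescales coordinates."""
    sx, sy = dst_w / src_w, dst_h / src_h
    x1, y1, x2, y2 = box
    return [x1 * sx, y1 * sy, x2 * sx, y2 * sy]


box = [100, 50, 300, 200]  # made-up box on a 640x480 image
before = encode_locs(box, 640, 480)
after = encode_locs(scale_box(box, 640, 480, 448, 448), 448, 448)
print(before)  # <loc0107><loc0160><loc0427><loc0480>
print(before == after)  # True
```

In other words, the resize mainly standardizes the pixel input; the training target stays essentially the same.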

Now, let’s create the collate function that prepares batches for input to the VLM.

In this step, we need to carefully handle the data augmentation process to ensure consistency and correctness.

from functools import partial

import numpy as np


# Create a data collator to encode text and image pairs
def collate_fn(examples, transform=None):
    images = []
    prompts = []
    suffixes = []
    for sample in examples:
        caption = sample["caption"]  # plain string, reused for the prompt below
        if transform:
            transformed = transform(image=np.array(sample["image"]), bboxes=[sample["bbox"]], category_ids=[caption])
            image = transformed["image"]
            target = {
                "bbox": transformed["bboxes"],
                "caption": transformed["category_ids"],
                "height": image.shape[0],
                "width": image.shape[1],
            }
        else:
            # Without augmentation, encode the original box against the original image size
            image = sample["image"]
            target = {"bbox": [sample["bbox"]], "caption": [caption], "height": image.height, "width": image.width}
        images.append([image])
        # Use the same "detect" prompt format as at inference time (a string, not a list)
        prompts.append(f"<image>detect {caption}")
        suffixes.append(format_objects(target)["label_for_paligemma"])
    batch = processor(images=images, text=prompts, suffix=suffixes, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    image_token_id = processor.tokenizer.additional_special_tokens_ids[
        processor.tokenizer.additional_special_tokens.index("<image>")
    ]
    # Mask padding and image tokens so they are not used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    batch["pixel_values"] = batch["pixel_values"].to(model.device)
    return batch


train_collate_fn = partial(collate_fn, transform=augmentations)
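To illustrate what the label masking in collate_fn does, here is a dependency-free sketch on a made-up sequence. The token ids are hypothetical; -100 is the default ignore_index of PyTorch’s cross-entropy loss, which is why positions set to it do not contribute to the loss. mask_labels is a hypothetical helper that mirrors the two masking lines above.

```python
IGNORE_INDEX = -100  # positions with this label are skipped by the loss


def mask_labels(input_ids, pad_id, image_id):
    """Copy input_ids and replace padding and image tokens with IGNORE_INDEX."""
    return [IGNORE_INDEX if tok in (pad_id, image_id) else tok for tok in input_ids]


# Hypothetical ids: 9 = <image>, 0 = <pad>, the rest are text tokens
ids = [9, 9, 9, 2, 15, 27, 1, 0, 0]
print(mask_labels(ids, pad_id=0, image_id=9))
# [-100, -100, -100, 2, 15, 27, 1, -100, -100]
```

Only the remaining text positions (prompt and detection string) are trained on.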

Finally, we can instantiate the SFTTrainer and start training our model!

from trl import SFTTrainer

trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    data_collator=train_collate_fn,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
trainer.train()

Let’s save the trained model and results to the Hugging Face Hub.

processor.save_pretrained(training_args.output_dir)
trainer.save_model(training_args.output_dir)
trainer.push_to_hub()

5. Test the Fine-Tuned Model

We have fine-tuned our model for grounded object detection. As the final step, let’s test its capabilities on a training sample and then on a sample from the validation split.

Model: sergiopaniego/paligemma2-3b-pt-448-od-grounding

Let’s instantiate our model using the fine-tuned checkpoint:

from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from peft import PeftModel
import torch

trained_model_id = "sergiopaniego/paligemma2-3b-pt-448-od-grounding"
model_id = "google/paligemma2-3b-pt-448"

# Load the frozen base model, then attach the fine-tuned LoRA adapter on top
base_model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
trained_model = PeftModel.from_pretrained(base_model, trained_model_id).eval()

trained_processor = PaliGemmaProcessor.from_pretrained(model_id, use_fast=True)

5.1 Test on a Training Sample

Let’s start by testing on one of the training images.
This gives us an initial sense of how the training went, but keep in mind it can be a bit misleading since the model has already seen this sample during training.

For this test, we’ll use the example we presented earlier to check if the model can now perform inference correctly.

image = train_dataset[20]["image"]
caption = train_dataset[20]["caption"]
prompt = f"<image>detect {caption}"
model_inputs = (
    trained_processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(trained_model.device)
)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = trained_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
    generation = generation[0][input_len:]  # keep only the newly generated tokens
    output = trained_processor.decode(generation, skip_special_tokens=True)
    print(output)
 middle vase
width, height = image.size
parsed_labels = parse_paligemma_labels(output, width, height)
parsed_labels
annotated_image = get_annotated_image(image, parsed_labels)

Let’s see if the fine-tuning was successful… 🥁

annotated_image

Nice! The model is now able to correctly recognize the “middle vase”.

5.2 Test Against a Validation Sample

Finally, let’s evaluate the model’s capabilities on a validation sample to properly assess whether it has learned both grounding and object detection.

image = val_dataset[13]["image"]
caption = val_dataset[13]["caption"]
caption
prompt = f"<image>detect {caption}"
model_inputs = (
    trained_processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(trained_model.device)
)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = trained_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
    generation = generation[0][input_len:]  # keep only the newly generated tokens
    output = trained_processor.decode(generation, skip_special_tokens=True)
    print(output)
 darker bear
width, height = image.size
parsed_labels = parse_paligemma_labels(output, width, height)
parsed_labels
annotated_image = get_annotated_image(image, parsed_labels)

Let’s check… 🥁

annotated_image

It works! Our model correctly identifies the “darker bear” in the image and no longer returns a detection for every bear.

Keep in mind that our training was light—using only a subset of the dataset—and the training configuration can be further optimized. We leave those improvements for you to explore!
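A single visual check is anecdotal. One common way to quantify grounding quality over a whole split is the Intersection over Union (IoU) between the predicted box and the ground-truth bbox, counting a prediction as correct when IoU is at least 0.5, the usual threshold for referring-expression comprehension on RefCOCO. The sketch below is a minimal starting point under that assumption; box_iou and grounding_accuracy are hypothetical helpers that work on xyxy boxes, such as the ones returned by parse_paligemma_labels and stored in the dataset’s bbox column.

```python
def box_iou(a, b):
    """Intersection over Union of two xyxy boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def grounding_accuracy(predictions, targets, threshold=0.5):
    """Fraction of samples whose first predicted box reaches the IoU threshold.

    `predictions` holds parsed outputs (lists of (label, box) pairs);
    an empty prediction counts as a miss.
    """
    hits = 0
    for parsed, target in zip(predictions, targets):
        if parsed and box_iou(parsed[0][1], target) >= threshold:
            hits += 1
    return hits / len(targets) if targets else 0.0


# Made-up boxes for illustration
print(box_iou([0, 0, 10, 10], [5, 5, 15, 15]))  # 25 / 175, about 0.143
print(grounding_accuracy([[("vase", [0, 0, 10, 10])], []], [[0, 0, 10, 10], [1, 1, 2, 2]]))  # 0.5
```

In practice you would generate and parse an output for each validation sample and compare it with the matching entry of val_dataset["bbox"], running the same loop with the base model to get a before/after comparison.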

6. Continuing the Learning Journey 🧑‍🎓️

To further enhance your understanding and skills, check out these valuable resources:

Feel free to explore these to deepen your knowledge and keep pushing the boundaries!
