paligemma-3b-mix-448-ft-TableDetection
This model is a version of google/paligemma-3b-mix-448 fine-tuned with bf16 mixed precision on the ucsahin/pubtables-detection-1500-samples dataset. It achieves the following results on the evaluation set:
- Loss: 1.3544
Model Details
- This model is a multimodal language model fine-tuned to detect tables in images given a text prompt. It takes an image and a text prompt as input and predicts bounding boxes around the tables in the image.
- Its primary purpose is to automate table detection in images. It can be used in applications such as document processing, data extraction, and image analysis, where locating tables within images is essential.
Inputs:
- Image: The model requires an image containing one or more tables as input. The image should be in a standard format such as JPEG or PNG.
- Text Prompt: A text prompt is also required to tell the model which task to perform. Use "detect table" as the text prompt.
Outputs:
- Bounding Boxes: The model outputs bounding box coordinates as special <loc[value]> tokens, where value is an integer from 0 to 1023 that represents a normalized coordinate. Each detection consists of four location tokens in the order y_min, x_min, y_max, x_max, followed by the label detected in that box. To convert the values to pixel coordinates in the original image, divide each value by 1024, then multiply the y values by the image height and the x values by the image width. A typical output looks like "<loc[value]><loc[value]><loc[value]><loc[value]> table; <loc[value]><loc[value]><loc[value]><loc[value]> table", with one segment per detected table. The following script converts this text output into PASCAL VOC formatted bounding boxes (xmin, ymin, xmax, ymax):
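```python
import re

def post_process(bbox_text, image_width, image_height):
    """Convert the model's <loc> output into PASCAL VOC boxes (xmin, ymin, xmax, ymax)."""
    # One segment per detection, separated by ";"
    loc_values_str = [bbox.strip() for bbox in bbox_text.split(";")]

    converted_bboxes = []
    for loc_value_str in loc_values_str:
        loc_values = re.findall(r'<loc(\d+)>', loc_value_str)
        loc_values = [int(x) for x in loc_values]
        # Skip segments with fewer than four location tokens (e.g. an empty trailing segment)
        if len(loc_values) < 4:
            continue
        # Keep only the first four values in case the model emits extra <loc> tags
        loc_values = loc_values[:4]
        # Normalize from the 0-1023 location grid to [0, 1)
        loc_values = [value / 1024 for value in loc_values]
        # Reorder (y_min, x_min, y_max, x_max) into (xmin, ymin, xmax, ymax) in pixels
        loc_values = [
            int(loc_values[1] * image_width), int(loc_values[0] * image_height),
            int(loc_values[3] * image_width), int(loc_values[2] * image_height),
        ]
        converted_bboxes.append(loc_values)

    return converted_bboxes
```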
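As a worked example of the arithmetic described above, the sketch below decodes a single detection string. The helper name `decode_detection` and the input values are hypothetical and chosen only for illustration; on a 1000 x 800 image, `<loc0128>` as x_min becomes 128 / 1024 × 1000 = 125 pixels, and `<loc0256>` as y_min becomes 256 / 1024 × 800 = 200 pixels.

```python
import re

def decode_detection(detection, image_width, image_height):
    """Hypothetical helper: return (xmin, ymin, xmax, ymax) in pixels for one detection string.

    Assumes at least four <loc> tokens; extra tokens beyond the first four are ignored.
    """
    y_min, x_min, y_max, x_max = [int(v) for v in re.findall(r"<loc(\d+)>", detection)[:4]]
    return (
        int(x_min / 1024 * image_width),
        int(y_min / 1024 * image_height),
        int(x_max / 1024 * image_width),
        int(y_max / 1024 * image_height),
    )

# Made-up model output on a 1000 x 800 image
print(decode_detection("<loc0256><loc0128><loc0768><loc0896> table", 1000, 800))
# → (125, 200, 875, 600)
```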
How to Get Started with the Model
In Transformers, you can load the model as follows:
```python
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch

model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"
device = "cuda:0"
dtype = torch.bfloat16

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device
)
processor = PaliGemmaProcessor.from_pretrained(model_id)
```
For inference, you can use the following:
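```python
from PIL import Image
# Load an input image (placeholder path; replace it with your own file)
image = Image.open("path/to/your_image.png").convert("RGB")

# Instruct the model to detect tables
prompt = "detect table"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
    # Keep only the newly generated tokens (drop the prompt)
    generation = generation[0][input_len:]
    bbox_text = processor.decode(generation, skip_special_tokens=True)
print(bbox_text)
# Convert the <loc> tokens into PASCAL VOC boxes with post_process from above
bboxes = post_process(bbox_text, image.width, image.height)
```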
Warning: You can also load a 4-bit or 8-bit quantized model using bitsandbytes. Be aware, however, that the quantized model may produce outputs that need further post-processing, for example five location tags "<loc[value]>" instead of four, or labels other than "table". The provided post-processing script handles the first case.
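For the second case (labels other than "table"), one option is to read the label that follows the location tokens and keep only the detections you expect. The sketch below is a hypothetical variant of the post-processing script, not part of the original card; it assumes that whatever text remains after removing the `<loc>` tags in a segment is the label, and it also skips segments with fewer than four location tokens.

```python
import re

def parse_detections(bbox_text, image_width, image_height, keep_label="table"):
    """Hypothetical helper: PASCAL VOC boxes for detections whose label equals keep_label."""
    boxes = []
    for segment in bbox_text.split(";"):
        locs = [int(v) for v in re.findall(r"<loc(\d+)>", segment)]
        # Assumption: the text left after removing the <loc> tags is the label
        label = re.sub(r"<loc\d+>", "", segment).strip()
        if len(locs) < 4 or label != keep_label:
            continue  # malformed segment or unexpected label
        # Use only the first four values (handles the five-tag case as well)
        y_min, x_min, y_max, x_max = (v / 1024 for v in locs[:4])
        boxes.append([
            int(x_min * image_width), int(y_min * image_height),
            int(x_max * image_width), int(y_max * image_height),
        ])
    return boxes

# Made-up output containing an unexpected label and a truncated segment
example = "<loc0256><loc0128><loc0768><loc0896> table; <loc0000><loc0000><loc0512><loc0512> figure; <loc0100>"
print(parse_detections(example, 1000, 800))
# → [[125, 200, 875, 600]]
```

Filtering on an exact label match is deliberately strict; if the model emits near-miss labels for your images, you may prefer to relax the comparison.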
Use the following to load the 4-bit quantized model:
```python
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, BitsAndBytesConfig
import torch

model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"
device = "cuda:0"
dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype
)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
    quantization_config=bnb_config
)
processor = PaliGemmaProcessor.from_pretrained(model_id)
```
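The warning above also mentions 8-bit loading, but only the 4-bit configuration is shown. A minimal sketch, assuming the same `from_pretrained` call as the 4-bit example, is to pass an 8-bit `BitsAndBytesConfig` as `quantization_config` instead; the `bnb_4bit_*` options do not apply in this mode.

```python
from transformers import BitsAndBytesConfig

# 8-bit variant: pass this as quantization_config in place of the 4-bit config
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
```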
Bias, Risks, and Limitations
Please refer to the google/paligemma-3b-mix-448 model card for bias, risks, and limitations.
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0001
- train_batch_size: 4
- eval_batch_size: 4
- seed: 42
- gradient_accumulation_steps: 4
- bf16 mixed precision: True
- total_train_batch_size: 16
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 5
- num_epochs: 3
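The batch size and scheduler settings above interact as follows: the effective batch size per optimizer step is train_batch_size × gradient_accumulation_steps = 4 × 4 = 16, and a "linear" schedule with 5 warmup steps ramps the learning rate up from 0 to 1e-4 and then decays it linearly to 0. The sketch below is intended to match the multiplier used by `get_linear_schedule_with_warmup` in Transformers; `total_steps` is a hypothetical value, since the exact number of optimizer steps in this run is not listed.

```python
train_batch_size = 4
gradient_accumulation_steps = 4
# Effective number of samples per optimizer step
total_train_batch_size = train_batch_size * gradient_accumulation_steps

def linear_lr(step, base_lr=1e-4, warmup_steps=5, total_steps=250):
    """Learning rate at an optimizer step: linear warmup, then linear decay to 0.

    total_steps=250 is a hypothetical default used only for illustration.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

print(total_train_batch_size)
for step in (0, 5, 125, 250):
    print(step, linear_lr(step))
```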
Training results
Training Loss | Epoch | Step | Validation Loss |
---|---|---|---|
2.957 | 0.1775 | 15 | 2.1300 |
1.9656 | 0.3550 | 30 | 1.8421 |
1.6716 | 0.5325 | 45 | 1.6898 |
1.5514 | 0.7101 | 60 | 1.5803 |
1.5851 | 0.8876 | 75 | 1.5271 |
1.4134 | 1.0651 | 90 | 1.4771 |
1.3566 | 1.2426 | 105 | 1.4528 |
1.3093 | 1.4201 | 120 | 1.4227 |
1.2897 | 1.5976 | 135 | 1.4115 |
1.256 | 1.7751 | 150 | 1.4007 |
1.2666 | 1.9527 | 165 | 1.3678 |
1.2213 | 2.1302 | 180 | 1.3744 |
1.0999 | 2.3077 | 195 | 1.3633 |
1.1931 | 2.4852 | 210 | 1.3606 |
1.0722 | 2.6627 | 225 | 1.3619 |
1.1485 | 2.8402 | 240 | 1.3544 |
Framework versions
- PEFT 0.11.1
- Transformers 4.42.0.dev0
- Pytorch 2.3.0+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1