Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import PIL.Image | |
import transformers | |
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor | |
import torch | |
import supervision as sv | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
# Shared supervision annotators, created once at import time and reused by every request.
BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load weights in bfloat16 on the GPU: this matches the bfloat16 inputs the inference
# code prepares and halves weight memory versus the float32 default. On the CPU keep
# float32 (bfloat16 CPU support varies by hardware -- assumption, confirm if needed).
DTYPE = torch.bfloat16 if DEVICE.type == "cuda" else torch.float32
# Backward-compatible alias for the duplicate lower-case name this module used to define.
device = DEVICE

model_id = "google/paligemma2-3b-pt-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=DTYPE).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)
def process_image(input_image, input_text, class_names, max_new_tokens=100):
    """Run PaliGemma 2 on an image and draw the parsed detections on it.

    Args:
        input_image: PIL image from the Gradio ``gr.Image(type="pil")`` input;
            ``None`` when the user submits without uploading one.
        input_text: Prompt passed as-is to the processor.
        class_names: Comma-separated class names. Only detections whose label
            appears in this list are kept (``classes=`` filter of
            ``sv.Detections.from_lmm``). Surrounding whitespace is ignored.
        max_new_tokens: Upper bound on generated tokens (default 100, the value
            previously hard-coded).

    Returns:
        Tuple of (annotated PIL image, raw decoded model output string).

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    # Normalise to 3-channel RGB so the processor and the RGB->BGR conversion
    # below both get the channel layout they expect (e.g. RGBA/grayscale uploads).
    image = input_image.convert("RGB")

    # Strip each name: "cat, dog" must yield "dog", not " dog", or the class
    # filter in from_lmm silently drops those detections. Empty entries are skipped.
    class_list = [name.strip() for name in (class_names or "").split(",") if name.strip()]

    # The annotators draw on a numpy array in OpenCV's BGR order; converted back at the end.
    cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Cast floating-point inputs to the model's own dtype (not a hard-coded
    # bfloat16) so they always match however the weights were loaded.
    model_inputs = (
        processor(text=input_text, images=image, return_tensors="pt")
        .to(model.dtype)
        .to(model.device)
    )
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the prompt tokens; keep only the newly generated continuation.
    generation = generation[0][input_len:]
    result = processor.decode(generation, skip_special_tokens=True)

    detections = sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=(image.width, image.height),
        classes=class_list,
    )

    # Draw mask fills first so boxes and labels stay visible on top of them.
    annotated_image = MASK_ANNOTATOR.annotate(scene=cv_image.copy(), detections=detections)
    annotated_image = BOX_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
    annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections)

    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_image), result
# Gradio UI: an image, a prompt and a comma-separated class list go in;
# the annotated image and the raw model text come out.
_image_input = gr.Image(type="pil")
_prompt_input = gr.Textbox(lines=2, placeholder="Enter text here...")
_classes_input = gr.Textbox(lines=1, placeholder="Enter class names separated by commas...")

app = gr.Interface(
    fn=process_image,
    inputs=[_image_input, _prompt_input, _classes_input],
    outputs=[gr.Image(type="pil"), gr.Textbox()],
    title="PaliGemma2 Image Detection with Supervision",
    description="Detect objects in an image using PaliGemma2 model.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    app.launch()