# NOTE: the lines below are web-page residue captured with this file, not code.
# onuralpszr's picture
# feat: ✨ initial commit added
# 344bc31 verified
# raw
# history blame
# 2.39 kB
import os
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import supervision as sv
import cv2
import numpy as np
from PIL import Image
import gradio as gr
# Annotators are created once at import time and reused by every request.
BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lower-case alias kept so any code referring to `device` keeps working;
# it previously re-evaluated the same expression as DEVICE.
device = DEVICE

# Checkpoint identifier shared by the model and its processor.
model_id = "google/paligemma2-3b-pt-448"

# Load the model (no explicit dtype is requested here), switch it to eval
# mode and move it to the selected device.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)
def process_image(input_image, input_text, class_names):
    """Run the PaliGemma2 model on an image and draw the parsed detections.

    Args:
        input_image: PIL image to analyse (the Gradio input uses type="pil").
        input_text: Prompt handed to the processor together with the image.
        class_names: Comma-separated class names. Surrounding whitespace is
            trimmed and empty entries are dropped before the list is passed
            as ``classes`` to ``sv.Detections.from_lmm``.

    Returns:
        A tuple ``(annotated_image, result)``: the annotated image as a PIL
        image and the decoded text generated by the model.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of failing inside
        # the colour conversion below.
        raise gr.Error("Please provide an input image.")

    # "cat, dog" must yield ["cat", "dog"], not ["cat", " dog"]; a plain
    # split(',') keeps the padding and turns empty input into [""].
    # NOTE(review): presumably the names must match the labels in the model
    # output exactly - confirm against the supervision parser.
    class_list = [
        name.strip()
        for name in (class_names or "").split(",")
        if name.strip()
    ]

    # The annotators draw on a BGR array; the PIL input is RGB.
    cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Cast to the model's own dtype rather than a hard-coded bfloat16: the
    # model is loaded without an explicit dtype, so this keeps inputs and
    # weights in the same precision.
    model_inputs = (
        processor(text=input_text, images=input_image, return_tensors="pt")
        .to(model.dtype)
        .to(model.device)
    )
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
        # Keep only the newly generated tokens (drop the prompt prefix).
        generation = generation[0][input_len:]
        result = processor.decode(generation, skip_special_tokens=True)

    detections = sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=(input_image.width, input_image.height),
        classes=class_list,
    )

    # Apply the box, label and mask annotators in turn, starting from a
    # copy so the converted input array itself is left untouched.
    annotated_image = BOX_ANNOTATOR.annotate(
        scene=cv_image.copy(),
        detections=detections,
    )
    annotated_image = LABEL_ANNOTATOR.annotate(
        scene=annotated_image,
        detections=detections,
    )
    annotated_image = MASK_ANNOTATOR.annotate(
        scene=annotated_image,
        detections=detections,
    )

    # Convert back to RGB and wrap as a PIL image for the Gradio output.
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(annotated_image)
    return annotated_image, result
# Gradio UI: an image plus two text boxes in, annotated image plus raw
# model text out. Input order matches the parameters of process_image.
app = gr.Interface(
    fn=process_image,
    inputs=[
        # Image delivered to the callback as a PIL image.
        gr.Image(type="pil"),
        # Prompt text for the model.
        gr.Textbox(lines=2, placeholder="Enter text here..."),
        # Comma-separated class names.
        gr.Textbox(lines=1, placeholder="Enter class names separated by commas..."),
    ],
    outputs=[
        gr.Image(type="pil"),
        gr.Textbox(),
    ],
    title="PaliGemma2 Image Detection with Supervision",
    description="Detect objects in an image using PaliGemma2 model.",
)
# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    app.launch()