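# PaliGemma 2 object detection demo for Hugging Face Spaces (ZeroGPU).
# The app runs google/paligemma2-3b-pt-448, parses the model's detection
# output with the supervision library, and serves a Gradio UI.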
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import supervision as sv
import cv2
import numpy as np
from PIL import Image
import gradio as gr
import spaces
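# Reusable supervision annotators that draw boxes, labels, and masks
# onto the image for each detection.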
BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "google/paligemma2-3b-pt-448"
# Load the weights in bfloat16 so they match the bfloat16 inputs
# prepared in process_image below.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)
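# @spaces.GPU allocates a ZeroGPU device for the duration of each call.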
@spaces.GPU
def process_image(input_image, input_text, class_names):
    # Strip stray whitespace so "person, dog" and "person,dog" both work.
    class_list = [name.strip() for name in class_names.split(',')]
cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
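    # Tokenize the prompt and preprocess the image; .to(torch.bfloat16)
    # casts only the floating-point tensors (the pixel values), leaving
    # the integer token ids untouched.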
model_inputs = processor(text=input_text, images=input_image, return_tensors="pt").to(torch.bfloat16).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
result = processor.decode(generation, skip_special_tokens=True)
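    # For a prompt like "detect person ; dog", `result` looks roughly like
    # "<loc0123><loc0456><loc0789><loc1012> person ; <loc...> dog": four
    # <locYYYY> tokens per box (y_min, x_min, y_max, x_max on a 0-1023
    # grid) followed by the class label (loc values here are illustrative).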
detections = sv.Detections.from_lmm(
sv.LMM.PALIGEMMA,
result,
resolution_wh=(input_image.width, input_image.height),
classes=class_list
)
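    # from_lmm converts that string into an sv.Detections object, rescaling
    # the 0-1023 coordinates to pixel-space xyxy boxes via resolution_wh
    # and mapping each label to a class id via class_list.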
annotated_image = BOX_ANNOTATOR.annotate(
scene=cv_image.copy(),
detections=detections
)
annotated_image = LABEL_ANNOTATOR.annotate(
scene=annotated_image,
detections=detections
)
annotated_image = MASK_ANNOTATOR.annotate(
scene=annotated_image,
detections=detections
)
annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
annotated_image = Image.fromarray(annotated_image)
return annotated_image, result
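# A minimal direct-call sketch (hypothetical file name, assuming the model
# fits on the local device):
#     annotated, raw = process_image(Image.open("photo.jpg"),
#                                    "detect person ; dog", "person, dog")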
with gr.Blocks() as app:
    gr.Markdown("""
## PaliGemma 2 Detection with Supervision - Demo
<div style="display: flex; gap: 10px;">
<a href="https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md">
<img src="https://img.shields.io/badge/Github-100000?style=flat&logo=github&logoColor=white" alt="Github">
</a>
<a href="https://huggingface.co/blog/paligemma">
<img src="https://img.shields.io/badge/Huggingface-FFD21E?style=flat&logo=Huggingface&logoColor=black" alt="Huggingface">
</a>
<a href="https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab">
</a>
<a href="https://arxiv.org/abs/2412.03555">
<img src="https://img.shields.io/badge/arXiv-B31B1B?style=flat&logo=arXiv&logoColor=white" alt="Paper">
</a>
<a href="https://supervision.roboflow.com/">
<img src="https://img.shields.io/badge/Supervision-6706CE?style=flat&logo=Roboflow&logoColor=white" alt="Supervision">
</a>
</div>
PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343)
vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile
model for transfer to a wide range of vision-language tasks such as image and short-video captioning, visual question
answering, text reading, object detection, and object segmentation.

This Space shows how to use PaliGemma 2 for object detection with the [supervision](https://supervision.roboflow.com/) library.
Provide an image, a detection prompt, and the matching class names to get an annotated image back.
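For example, prompt `detect person ; dog` and enter class names `person, dog`.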
""")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="Input Image")
            input_text = gr.Textbox(lines=2, placeholder="e.g. detect person ; dog", label="Prompt")
class_names = gr.Textbox(lines=1, placeholder="Enter class names separated by commas...", label="Class Names")
with gr.Column():
annotated_image = gr.Image(type="pil", label="Annotated Image")
detection_result = gr.Textbox(label="Detection Result")
gr.Button("Submit").click(
fn=process_image,
inputs=[input_image, input_text, class_names],
outputs=[annotated_image, detection_result]
)
if __name__ == "__main__":
    app.launch()