# Hugging Face file-viewer header (page residue captured with the source):
#   Amazingldl's picture
#   Create app.py
#   aede3d3
#   raw
#   history blame
#   2.01 kB
import gradio as gr
import torch
import numpy as np
from typing import List
from PIL import Image, ImageDraw
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# One checkpoint id feeds both loaders so the processor and the detector
# always come from the same pretrained weights.
_CHECKPOINT = "google/owlvit-base-patch32"
processor = OwlViTProcessor.from_pretrained(_CHECKPOINT)
model = OwlViTForObjectDetection.from_pretrained(_CHECKPOINT)
def pro_process(labelstring: str) -> List[str]:
    """Split a comma-separated label string into a list of clean labels.

    Args:
        labelstring: Text such as ``"cat, remote control"``. A falsy value
            (``""`` or ``None``) is treated as "no labels".

    Returns:
        The labels in input order with surrounding whitespace stripped.
        Blank entries produced by empty input or stray commas
        (``"cat, dog,"``) are dropped so they are never used as queries.
    """
    if not labelstring:
        return []
    stripped = (part.strip() for part in labelstring.split(","))
    return [label for label in stripped if label]
def inference(img: Image.Image, labels: str) -> Image.Image:
    """Detect the requested objects in ``img`` and draw them on a copy.

    Args:
        img: Image to search. ``None`` (nothing uploaded) is rejected.
        labels: Comma-separated text queries, e.g. ``"cat, remote control"``.

    Returns:
        A copy of ``img`` with a red rectangle and a ``"label: score"``
        caption for every detection kept by the 0.1 score threshold.
        The input image itself is left unmodified.

    Raises:
        gr.Error: If no image was supplied or ``labels`` contains no
            non-blank query.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")

    # Filter blanks here as well so stray commas never become model queries,
    # whatever the splitter returns for them.
    queries = [query for query in pro_process(labels) if query]
    if not queries:
        raise gr.Error("Please provide at least one label.")
    print(queries)

    inputs = processor(text=queries, images=img, return_tensors="pt")
    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); reversing it gives (height, width).
    target_sizes = torch.Tensor([img.size[::-1]])
    results = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=0.1
    )
    detections = results[0]  # only one image was submitted

    # Draw on a copy so the caller's image is not mutated.
    annotated = img.copy()
    draw = ImageDraw.Draw(annotated)
    for box, score, label_index in zip(
        detections["boxes"].tolist(),
        detections["scores"].tolist(),
        detections["labels"].tolist(),
    ):
        xmin, ymin, xmax, ymax = (round(coord, 2) for coord in box)
        draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
        draw.text((xmin, ymin), f"{queries[label_index]}: {round(score, 2)}", fill="white")
    return annotated
# UI definition: an input column (image, label textbox, trigger button) next to
# an output column showing the annotated result.
# NOTE(review): nesting was reconstructed from whitespace-stripped source —
# confirm the button belongs in the input column.
with gr.Blocks(
    title="Zero-shot object detection",
    theme="freddyaboulton/dracula_revamped",
) as demo:
    gr.Markdown("## Zero-shot object detection")
    with gr.Row():
        with gr.Column():
            in_img = gr.Image(label="Input Image", type="pil")
            in_labels = gr.Textbox(label="Input labels, comma apart")
            inference_btn = gr.Button("Inference", variant="primary")
        with gr.Column():
            out_img = gr.Image(label="Result", interactive=False)

    # Clicking the button runs detection on the current image and label text.
    inference_btn.click(
        fn=inference,
        inputs=[in_img, in_labels],
        outputs=[out_img],
    )
if __name__ == "__main__":
    # Turn on request queuing, then serve on the loopback address only.
    queued_demo = demo.queue()
    queued_demo.launch(server_name="127.0.0.1")