Amazingldl commited on
Commit
aede3d3
·
1 Parent(s): be18920

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from typing import List
5
+ from PIL import Image, ImageDraw
6
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection
7
+
8
+ processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
9
+ model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
10
+
11
+
12
+ def pro_process(labelstring):
13
+ labels = labelstring.split(",")
14
+ labels = [i.strip() for i in labels]
15
+ return labels
16
+
17
+
18
+ def inference(img: Image.Image, labels: List[str]) -> Image.Image:
19
+ labels = pro_process(labels)
20
+ print(labels)
21
+ inputs = processor(text=labels, images=img, return_tensors="pt")
22
+ outputs = model(**inputs)
23
+ target_sizes = torch.Tensor([img.size[::-1]])
24
+ results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)
25
+ i = 0
26
+ boxes, scores, labels_index = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
27
+ draw = ImageDraw.Draw(img)
28
+ for box, score, label_index in zip(boxes, scores, labels_index):
29
+ box = [round(i, 2) for i in box.tolist()]
30
+ xmin, ymin, xmax, ymax = box
31
+ draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
32
+ draw.text((xmin, ymin), f"{labels[label_index]}: {round(float(score),2)}", fill="white")
33
+ return img
34
+
35
+
36
+ with gr.Blocks(title="Zero-shot object detection", theme="freddyaboulton/dracula_revamped") as demo:
37
+ gr.Markdown(""
38
+ "## Zero-shot object detection"
39
+ "")
40
+ with gr.Row():
41
+ with gr.Column():
42
+ in_img = gr.Image(label="Input Image", type="pil")
43
+ in_labels = gr.Textbox(label="Input labels, comma apart")
44
+ inference_btn = gr.Button("Inference", variant="primary")
45
+ with gr.Column():
46
+ out_img = gr.Image(label="Result", interactive=False)
47
+
48
+ inference_btn.click(inference, inputs=[in_img, in_labels], outputs=[out_img])
49
+
50
+ if __name__ == "__main__":
51
+ demo.queue().launch(server_name="127.0.0.1")