antoniorached commited on
Commit
93bb23b
·
verified ·
1 Parent(s): bb51e1a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
3
+ from PIL import Image, ImageDraw
4
+ import gradio as gr
5
+
6
# OWL-ViT checkpoint used for zero-shot (open-vocabulary) object detection.
checkpoint = "google/owlvit-base-patch32"

# Load the preprocessing pipeline and the detection model once at startup,
# so each Gradio request only pays for inference.
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)
10
+
11
def detect_objects(image, text_queries):
    """Run zero-shot object detection and draw labeled boxes on the image.

    Args:
        image: a PIL image, or a filesystem path to an image file.
        text_queries: candidate object labels — either a list of strings or a
            single comma-separated string (as produced by the Gradio textbox).

    Returns:
        The PIL image with red bounding boxes and "label: score" text drawn
        on it for every detection scoring above the threshold.
    """
    if isinstance(image, str):
        image = Image.open(image)

    # The Gradio textbox delivers one comma-separated string.  Both the
    # processor and the ``text_queries[label]`` lookup below need a list of
    # individual query strings — indexing a raw string with the predicted
    # label id would return single characters, not labels.
    if isinstance(text_queries, str):
        text_queries = [q.strip() for q in text_queries.split(",") if q.strip()]

    inputs = processor(images=image, text=text_queries, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        # post_process_object_detection expects (height, width);
        # PIL's ``image.size`` is (width, height), hence the reversal.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, threshold=0.1, target_sizes=target_sizes
        )[0]

    draw = ImageDraw.Draw(image)

    scores = results["scores"].tolist()
    labels = results["labels"].tolist()
    boxes = results["boxes"].tolist()

    for box, score, label in zip(boxes, scores, labels):
        xmin, ymin, xmax, ymax = box
        draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
        draw.text((xmin, ymin), f"{text_queries[label]}: {round(score, 2)}", fill="black")

    return image
35
+
36
# UI wiring: an image plus a comma-separated query string in, an annotated
# image out.
inputs = [
    gr.Image(type="pil", label="Input Image"),
    gr.Textbox(label="Text Queries (comma-separated)"),
]

output = gr.Image(type="pil", label="Output Image")

demo = gr.Interface(
    fn=detect_objects,
    inputs=inputs,
    outputs=output,
    title="Zero-Shot Object Detection",
    description="Detect objects in an image using zero-shot object detection.",
)

demo.launch()