ya02 committed on
Commit
342bbfd
·
verified ·
1 Parent(s): d755122

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
3
+ from PIL import Image, ImageDraw
4
+ import gradio as gr
5
+
6
# Hugging Face Hub identifier of the pre-trained OWL-ViT checkpoint
checkpoint = "google/owlvit-base-patch32"

# Load the processor and the detection model once, at import time, so that
# every request served by the app reuses the same instances
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)
12
+
13
def detect_objects(image, text_queries):
    """Run zero-shot object detection and draw the detections on the image.

    Args:
        image: A PIL image, a path to an image file, or ``None`` (no image
            supplied).
        text_queries: Either a single string of comma-separated queries (the
            form delivered by the text box in the UI) or an iterable of query
            strings.

    Returns:
        An annotated copy of the image carrying one red box and a
        ``"<query>: <score>"`` caption per detection. ``None`` is returned
        when no image was supplied; the image is returned without running the
        model when no non-empty query was supplied.
    """
    if image is None:
        return None

    # Load from disk when handed a path. The context manager closes the file
    # handle, and convert() forces the pixel data to be read before it closes.
    if isinstance(image, str):
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    # A plain string must be split into individual queries; otherwise the
    # whole string is treated as one query and label indices would pick out
    # single characters of it when captions are drawn.
    if isinstance(text_queries, str):
        text_queries = text_queries.split(",")
    queries = [q.strip() for q in (text_queries or []) if q and q.strip()]
    if not queries:
        return image

    # Prepare inputs for zero-shot object detection
    inputs = processor(images=image, text=queries, return_tensors="pt")

    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); reverse it to (height, width) for rescaling
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, threshold=0.1, target_sizes=target_sizes
    )[0]

    # Draw on a copy so the caller's image object is left untouched
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)

    detections = zip(
        results["boxes"].tolist(),
        results["scores"].tolist(),
        results["labels"].tolist(),
    )
    for box, score, label in detections:
        xmin, ymin, xmax, ymax = box
        draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
        draw.text((xmin, ymin), f"{queries[label]}: {round(score, 2)}", fill="black")

    return annotated
42
+
43
# Gradio UI: an image upload and a text box in, the annotated image out
image_input = gr.Image(type="pil", label="Upload an Image")
query_input = gr.Textbox(
    lines=2,
    placeholder="Enter text queries separated by commas...",
    label="Text Queries",
)
annotated_output = gr.Image(label="Detected Objects")

demo = gr.Interface(
    fn=detect_objects,
    inputs=[image_input, query_input],
    outputs=annotated_output,
    title="AI Workshop Zero-Shot Object Detection",
    description="Upload an image and provide text queries to perform zero-shot object detection using a pre-trained model. The model identifies objects based on the queries you provide.",
)
demo.launch()