aman5614 commited on
Commit
a882a86
·
verified ·
1 Parent(s): 6acf8ea

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
+ from PIL import Image,ImageDraw
4
+ import requests
5
+ import gradio as gr
6
+ from gtts import gTTS
7
+ import random
8
+ from collections import Counter
9
+
10
+
11
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
12
+ image = Image.open(requests.get(url, stream=True).raw)
13
+
14
+ # you can specify the revision tag if you don't want the timm dependency
15
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
16
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
17
+
18
+ inputs = processor(images=image, return_tensors="pt")
19
+ outputs = model(**inputs)
20
+
21
+ # convert outputs (bounding boxes and class logits) to COCO API
22
+ # let's only keep detections with score > 0.9
23
+ target_sizes = torch.tensor([image.size[::-1]])
24
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
25
+
26
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
27
+ box = [round(i, 2) for i in box.tolist()]
28
+ print(
29
+ f"Detected {model.config.id2label[label.item()]} with confidence "
30
+ f"{round(score.item(), 3)} at location {box}"
31
+ )
32
+
33
+
34
+ # Load model and processor
35
+ model_name = "facebook/detr-resnet-50"
36
+ processor = DetrImageProcessor.from_pretrained(model_name)
37
+ model = DetrForObjectDetection.from_pretrained(model_name)
38
+
39
+ # Move model to GPU if available
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+ model.to(device)
42
+
43
+ # Function to generate random colors
44
+ def random_color():
45
+ return "#{:02x}{:02x}{:02x}".format(random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))
46
+
47
+ # Object detection function
48
+ def detect_objects(image):
49
+ # Resize image for better detection
50
+ image = image.resize((800, 800))
51
+
52
+ # Process image
53
+ inputs = processor(images=image, return_tensors="pt").to(device)
54
+
55
+ with torch.no_grad():
56
+ outputs = model(**inputs)
57
+
58
+ # Extract bounding boxes and labels
59
+ target_sizes = [image.size[::-1]]
60
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
61
+
62
+ # Apply confidence threshold
63
+ keep = results["scores"] > 0.5
64
+ boxes = results["boxes"][keep]
65
+ labels = results["labels"][keep]
66
+
67
+ # Create a copy of the image
68
+ image_draw = image.copy()
69
+ draw = ImageDraw.Draw(image_draw)
70
+
71
+ label_counts = Counter()
72
+ colors = {}
73
+
74
+ # Draw bounding boxes and count labels
75
+ for box, label in zip(boxes, labels):
76
+ box = [int(i) for i in box.tolist()]
77
+ label_text = model.config.id2label[label.item()]
78
+
79
+ label_counts[label_text] += 1 # Count occurrences
80
+
81
+ if label_text not in colors:
82
+ colors[label_text] = random_color()
83
+
84
+ draw.rectangle(box, outline=colors[label_text], width=5)
85
+
86
+ # Prepare HTML output for labels
87
+ styled_labels = [
88
+ f"<span style='background-color:{colors[label]}; color:white; padding:8px 15px; border-radius:10px; margin-right:10px;'>"
89
+ f"{label} (x{count})</span>"
90
+ for label, count in label_counts.items()
91
+ ]
92
+
93
+ labels_html = "<div style='display:flex; flex-wrap:wrap; gap:10px;'>" + " ".join(styled_labels) + "</div>"
94
+
95
+ # Convert detected objects into speech
96
+ detected_objects = ", ".join([f"{label} ({count} times)" for label, count in label_counts.items()])
97
+ description = f"I detected the following objects: {detected_objects}." if detected_objects else "No objects detected, please try another image."
98
+
99
+ # Save audio
100
+ audio_path = "detected_objects.mp3"
101
+ tts = gTTS(description)
102
+ tts.save(audio_path)
103
+
104
+ return image_draw, labels_html, audio_path
105
+
106
+ # Gradio Interface
107
+ interface = gr.Interface(
108
+ fn=detect_objects,
109
+ inputs=gr.Image(type="pil", label="Upload an Image"),
110
+ outputs=[
111
+ gr.Image(label="Detected Objects"),
112
+ gr.HTML(label="Detected Labels"),
113
+ gr.Audio(label="Audio Description")
114
+ ],
115
+ title="AI Assistant for Visually Impaired",
116
+ description="This app detects objects in an image and provides an audio description."
117
+ )
118
+
119
+ # Launch
120
+ interface.launch()