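"""Gradio app: DETR (facebook/detr-resnet-50) object detection with an HTML
label summary and a gTTS audio description, aimed at visually impaired users.
Note: the COCO sample download and gTTS both require internet access at runtime."""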
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image, ImageDraw
import requests
import gradio as gr
from gtts import gTTS
import random
from collections import Counter


# --- Sanity check: run the stock DETR example on a COCO sample image ---
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Specify the revision tag to avoid the timm dependency; load once and reuse for the app below
model_name = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(model_name, revision="no_timm")
model = DetrForObjectDetection.from_pretrained(model_name, revision="no_timm")

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# convert outputs (bounding boxes and class logits) to COCO API
# let's only keep detections with score > 0.9
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )



# Generate a random bright hex color (channel values 100-255 keep boxes visible)
def random_color():
    return "#{:02x}{:02x}{:02x}".format(
        random.randint(100, 255), random.randint(100, 255), random.randint(100, 255)
    )

# Object detection function: returns the annotated image, an HTML label summary, and an audio file path
def detect_objects(image):
    # Normalize to RGB and resize. Note: a fixed 800x800 resize distorts the
    # aspect ratio, but boxes are drawn on the same resized copy, so they still align.
    image = image.convert("RGB").resize((800, 800))

    # Process image
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Run inference without tracking gradients (inference-only, saves memory)
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process logits into absolute (x0, y0, x1, y1) boxes; target size is (height, width)
    target_sizes = [image.size[::-1]]
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]

    # Apply confidence threshold
    keep = results["scores"] > 0.5
    boxes = results["boxes"][keep]
    labels = results["labels"][keep]

    # Create a copy of the image
    image_draw = image.copy()
    draw = ImageDraw.Draw(image_draw)

    label_counts = Counter()
    colors = {}

    # Draw bounding boxes and count labels
    for box, label in zip(boxes, labels):
        box = [int(i) for i in box.tolist()]
        label_text = model.config.id2label[label.item()]

        label_counts[label_text] += 1  # Count occurrences

        if label_text not in colors:
            colors[label_text] = random_color()

        draw.rectangle(box, outline=colors[label_text], width=5)

    # Prepare HTML output for labels
    styled_labels = [
        f"<span style='background-color:{colors[label]}; color:white; padding:8px 15px; border-radius:10px; margin-right:10px;'>"
        f"{label} (x{count})</span>"
        for label, count in label_counts.items()
    ]

    labels_html = "<div style='display:flex; flex-wrap:wrap; gap:10px;'>" + " ".join(styled_labels) + "</div>"

    # Build a spoken summary of the detections
    detected_objects = ", ".join(
        f"{label} ({count} time{'s' if count > 1 else ''})"
        for label, count in label_counts.items()
    )
    description = (
        f"I detected the following objects: {detected_objects}."
        if detected_objects
        else "No objects detected, please try another image."
    )

    # Save the spoken description as an MP3 (overwritten on each call)
    audio_path = "detected_objects.mp3"
    tts = gTTS(description)
    tts.save(audio_path)

    return image_draw, labels_html, audio_path
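
# A minimal sketch of calling detect_objects directly (outside Gradio) for quick
# testing, assuming a hypothetical local file "example.jpg":
#
#   img = Image.open("example.jpg")
#   annotated, labels_html, audio_path = detect_objects(img)
#   annotated.save("annotated.jpg")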

# Gradio Interface
interface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Image(label="Detected Objects"),
        gr.HTML(label="Detected Labels"),
        gr.Audio(label="Audio Description")
    ],
    title="AI Assistant for Visually Impaired",
    description="This app detects objects in an image and provides an audio description."
)

# Launch the app (pass share=True to launch() for a temporary public link)
interface.launch()
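
# Assumed dependencies, inferred from the imports above (a sketch of a
# requirements list; exact pins are not specified in this file):
#   torch, transformers, pillow, requests, gradio, gtts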