# aman5614's picture
# Create app.py
# a882a86 verified
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image,ImageDraw
import requests
import gradio as gr
from gtts import gTTS
import random
from collections import Counter
# Demo run executed at import time: fetch one sample COCO image, run DETR on
# it and print every detection. `processor` and `model` are rebound further
# down in this file before the Gradio app uses them.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Without a timeout a stalled connection would block startup indefinitely;
# raise_for_status() turns an HTTP error into an exception here instead of
# letting an error page reach Image.open().
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
# you can specify the revision tag if you don't want the timm dependency
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
inputs = processor(images=image, return_tensors="pt")
# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)
# convert outputs (bounding boxes and class logits) to COCO API
# let's only keep detections with score > 0.9
# PIL reports size as (width, height); reversing it yields (height, width).
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )
# Checkpoint name shared by the image processor and the detection model
model_name = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(model_name)
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the detector and place it on the selected device in one step
model = DetrForObjectDetection.from_pretrained(model_name).to(device)
# Function to generate random colors
def random_color():
    """Return a random ``#rrggbb`` hex colour with every channel in 100..255."""
    # Draw the channels one at a time, in red/green/blue order.
    red = random.randint(100, 255)
    green = random.randint(100, 255)
    blue = random.randint(100, 255)
    return "#{:02x}{:02x}{:02x}".format(red, green, blue)
# Object detection function
def detect_objects(image):
    """Detect objects in ``image`` and describe them as picture, HTML and audio.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            nothing was uploaded.

    Returns:
        A 3-tuple ``(image_draw, labels_html, audio_path)``: a copy of the
        resized image with one rectangle per detection, an HTML snippet with
        one coloured badge per detected label, and the path of an MP3 file
        holding the spoken description.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    # Without this guard a missing upload fails with an opaque AttributeError
    # on ``image.resize``; gr.Error carries a readable message instead.
    if image is None:
        raise gr.Error("Please upload an image before running detection.")

    # Resize image for better detection
    # NOTE(review): a fixed 800x800 size ignores the aspect ratio of the
    # upload, and the returned picture is the resized one - confirm intended.
    image = image.resize((800, 800))

    # Process image
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Extract bounding boxes and labels. PIL reports (width, height), so the
    # size is reversed to (height, width). The confidence threshold is passed
    # here, which makes a second score mask on the results unnecessary.
    score_threshold = 0.5
    target_sizes = [image.size[::-1]]
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=score_threshold
    )[0]

    # Draw on a copy so the resized input itself stays untouched
    image_draw = image.copy()
    draw = ImageDraw.Draw(image_draw)
    label_counts = Counter()
    colors = {}

    # Draw bounding boxes and count labels
    for box, label in zip(results["boxes"], results["labels"]):
        box = [int(i) for i in box.tolist()]
        label_text = model.config.id2label[label.item()]
        label_counts[label_text] += 1  # Count occurrences
        # One colour per label, so all boxes of a class share an outline
        if label_text not in colors:
            colors[label_text] = random_color()
        draw.rectangle(box, outline=colors[label_text], width=5)

    # Prepare HTML output for labels
    styled_labels = [
        f"<span style='background-color:{colors[label]}; color:white; padding:8px 15px; border-radius:10px; margin-right:10px;'>"
        f"{label} (x{count})</span>"
        for label, count in label_counts.items()
    ]
    labels_html = "<div style='display:flex; flex-wrap:wrap; gap:10px;'>" + " ".join(styled_labels) + "</div>"

    # Convert detected objects into speech ("1 time" / "2 times")
    detected_objects = ", ".join(
        f"{label} ({count} {'time' if count == 1 else 'times'})"
        for label, count in label_counts.items()
    )
    if detected_objects:
        description = f"I detected the following objects: {detected_objects}."
    else:
        description = "No objects detected, please try another image."

    # Save audio to a unique file per call: a single fixed filename would be
    # overwritten when two requests run at the same time. The handle is closed
    # before gTTS writes to the path; the file is not deleted here.
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as audio_file:
        audio_path = audio_file.name
    tts = gTTS(description)
    tts.save(audio_path)

    return image_draw, labels_html, audio_path
# Gradio Interface
# Single input: the uploaded picture, requested from Gradio as a PIL image
image_input = gr.Image(type="pil", label="Upload an Image")
# Output components, in the order detect_objects returns its three values
result_components = [
    gr.Image(label="Detected Objects"),
    gr.HTML(label="Detected Labels"),
    gr.Audio(label="Audio Description"),
]
interface = gr.Interface(
    fn=detect_objects,
    inputs=image_input,
    outputs=result_components,
    title="AI Assistant for Visually Impaired",
    description="This app detects objects in an image and provides an audio description.",
)
# Launch
interface.launch()