# Harshithtd's picture
# Update app.py
# bb7f5b6 verified
from typing import List
import os
import numpy as np
import torch
import gradio as gr
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import supervision as sv
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pre-processor and detection model are loaded from the same checkpoint;
# the model is then moved to the selected device.
processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
model = AutoModelForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
model = model.to(device)

# Annotators and tracker are created once at import time and reused.
# NOTE(review): newer supervision releases appear to rename
# BoundingBoxAnnotator to BoxAnnotator - confirm against the pinned version.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
TRACKER = sv.ByteTrack()
def annotate_image(input_image: np.ndarray, detections, labels: List[str]) -> np.ndarray:
    """Draw masks, bounding boxes and labels for ``detections`` on an image.

    Args:
        input_image: Image array to annotate. It is copied first, so the
            caller's array is not modified.
        detections: Detections object accepted by the supervision annotators.
        labels: One label string per detection, passed to the label annotator.

    Returns:
        The annotated copy of ``input_image``.
    """
    # The annotators draw onto the scene they are handed, so work on a copy
    # rather than on the caller's array.
    canvas = input_image.copy()
    canvas = MASK_ANNOTATOR.annotate(canvas, detections)
    canvas = BOUNDING_BOX_ANNOTATOR.annotate(canvas, detections)
    canvas = LABEL_ANNOTATOR.annotate(canvas, detections, labels=labels)
    return canvas
def process_image(input_image: np.ndarray, confidence_threshold: float):
    """Detect objects in one image and return the annotated image and labels.

    Args:
        input_image: Image array; it is wrapped with ``Image.fromarray``
            before being handed to ``query``. May be ``None`` when no image
            was supplied.
        confidence_threshold: Score threshold forwarded to ``query``.

    Returns:
        A ``(annotated_image, labels_text)`` tuple, where ``labels_text`` is
        the detected class names joined with ``", "``. When ``input_image``
        is ``None`` the result is ``(None, "")``.
    """
    # Nothing to detect on when no image was supplied; return empty outputs
    # instead of failing inside Image.fromarray.
    if input_image is None:
        return None, ""

    results = query(Image.fromarray(input_image), confidence_threshold)
    detections = sv.Detections.from_transformers(results[0])

    # Each call handles an unrelated still image. A ByteTrack instance keeps
    # track state between update_with_detections calls, so sharing one
    # tracker across calls lets state from earlier images decide which
    # detections are kept for this one. Use a fresh tracker per call.
    tracker = sv.ByteTrack()
    detections = tracker.update_with_detections(detections)

    final_labels = [model.config.id2label[label] for label in detections.class_id.tolist()]
    output_image = annotate_image(input_image, detections, final_labels)
    return output_image, ", ".join(final_labels)
def query(image: Image.Image, confidence_threshold: float):
    """Run the detection model on ``image`` and post-process its outputs.

    Args:
        image: PIL image to run detection on.
        confidence_threshold: Passed as ``threshold`` to the processor's
            ``post_process_object_detection``.

    Returns:
        Whatever ``processor.post_process_object_detection`` returns for the
        single-image batch.
    """
    model_inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        raw_outputs = model(**model_inputs)

    # PIL reports size as (width, height); the target size is given as
    # (height, width), one row per image in the batch.
    width, height = image.size
    sizes = torch.tensor([[height, width]])

    return processor.post_process_object_detection(
        outputs=raw_outputs,
        threshold=confidence_threshold,
        target_sizes=sizes,
    )
def run_demo():
    """Build the Gradio interface for the detector and launch it."""
    # Input widgets: the image to analyse and the detection score threshold.
    image_in = gr.Image(label="Input Image", type="numpy")
    threshold_slider = gr.Slider(
        label="Confidence Threshold",
        minimum=0.1,
        maximum=1.0,
        value=0.6,
        step=0.05,
    )

    # Output widgets: the annotated image and the detected class names.
    image_out = gr.Image(label="Output Image", type="numpy")
    classes_out = gr.Textbox(label="Detected Classes")

    def process_and_display(input_image, conf):
        """Forward the widget values to ``process_image`` and return its pair."""
        annotated, class_text = process_image(input_image, conf)
        return annotated, class_text

    demo = gr.Interface(
        fn=process_and_display,
        inputs=[image_in, threshold_slider],
        outputs=[image_out, classes_out],
        title="Real Time Object Detection with RT-DETR",
        description="This demo uses RT-DETR for object detection in images. Adjust the confidence threshold to see different results.",
    )
    demo.launch()
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_demo()