# File-view residue: "Update app.py" by techysanoj, commit bebfcb3, 1.55 kB.
import gradio as gr
import torch
import torchvision.transforms as transforms
from torchvision.models.detection import detr
from PIL import Image
import cv2
import numpy as np
# Load the pretrained DETR model.
# torchvision does not ship a `detr` detection module, so the ResNet-50 DETR
# checkpoint is pulled from the official torch.hub entry point instead.
# NOTE(review): the `from torchvision.models.detection import detr` line in the
# import block is not used by this code and fails at import time; drop it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.hub.load("facebookresearch/detr:main", "detr_resnet50", pretrained=True)
# eval() freezes dropout / batch-norm statistics for inference.
model = model.to(device).eval()
# Preprocessing applied to every frame before it is fed to the model:
# tensor conversion, a fixed 800x800 resize, and the ImageNet mean/std
# normalisation the pretrained ResNet backbone was trained with.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((800, 800)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Define the object detection function
def detect_objects(frame, score_threshold=0.7):
    """Run the detector on one frame and draw the confident detections on it.

    Args:
        frame: image as a NumPy array (height x width x channels). RGB channel
            order is assumed, since the array is handed straight to PIL.
        score_threshold: minimum class probability a detection needs in order
            to be drawn.

    Returns:
        A copy of ``frame`` with a green rectangle and a ``Class: <id>``
        caption for every kept detection, or ``None`` when ``frame`` is
        ``None``.
    """
    if frame is None:
        return None

    height, width = frame.shape[:2]
    # OpenCV draws in place and needs a writable, C-contiguous buffer; working
    # on a copy also leaves the caller's array untouched.
    annotated = np.array(frame, order="C")

    # convert("RGB") drops an alpha channel / expands greyscale so the
    # three-channel normalisation in `transform` always applies.
    image = Image.fromarray(frame).convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)

    # DETR returns raw class logits per query; the last logit is the
    # "no object" class and is excluded before picking the best label.
    probabilities = outputs['pred_logits'][0].softmax(-1)[:, :-1]
    scores, labels = probabilities.max(-1)
    keep = scores > score_threshold

    # Boxes come back as normalised (cx, cy, w, h); turn them into pixel
    # (x0, y0, x1, y1) corners in the coordinate system of the input frame.
    boxes = outputs['pred_boxes'][0][keep].cpu().numpy()
    centers, sizes = boxes[:, :2], boxes[:, 2:]
    corners = np.concatenate([centers - sizes / 2, centers + sizes / 2], axis=1)
    corners = (corners * [width, height, width, height]).astype(int)
    kept_labels = labels[keep].cpu().numpy()

    for (x0, y0, x1, y1), label in zip(corners.tolist(), kept_labels.tolist()):
        cv2.rectangle(annotated, (x0, y0), (x1, y1), (0, 255, 0), 2)
        # Keep the caption inside the frame when the box touches the top edge.
        cv2.putText(annotated, f'Class: {label}', (x0, max(y0 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2, cv2.LINE_AA)

    return annotated
def detect_objects_in_video(video_path):
    """Annotate every frame of a video file and return the path of the result.

    The Gradio video component hands the callback a file path rather than
    decoded frames, so this wrapper decodes the file, runs ``detect_objects``
    on each frame and re-encodes the annotated frames.

    Args:
        video_path: path of the input video, or ``None`` when the input is
            cleared.

    Returns:
        Path of a temporary ``.mp4`` file with the annotated video, or
        ``None`` when no input was given.

    Raises:
        ValueError: if the file cannot be opened as a video.
    """
    import tempfile  # local import: only this helper needs it

    if video_path is None:
        return None

    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    # Some containers report 0 fps; fall back to a sane default.
    fps = capture.get(cv2.CAP_PROP_FPS) or 25.0
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Reserve a file name; the writer below creates the actual content.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
        output_path = handle.name
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                             fps, (width, height))

    try:
        while True:
            ok, bgr_frame = capture.read()
            if not ok:
                break
            # OpenCV decodes to BGR, while detect_objects works on RGB frames.
            rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
            annotated = detect_objects(rgb_frame)
            writer.write(cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
    finally:
        capture.release()
        writer.release()

    return output_path


# Define the Gradio interface
# `capture_session` (a TensorFlow 1.x workaround) is omitted: it is not an
# accepted Interface argument in current Gradio releases.
iface = gr.Interface(
    fn=detect_objects_in_video,
    inputs=gr.Video(),
    outputs="video",
    live=True,
)

# Launch the Gradio app
iface.launch()