Spaces: Runtime error
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np

# Load a pretrained DETR model from the official facebookresearch/detr hub entry.
# torchvision does not ship a DETR implementation, so the original
# `from torchvision.models.detection import detr` fails at import time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.hub.load("facebookresearch/detr:main", "detr_resnet50", pretrained=True)
model = model.to(device).eval()

# Define the transformation for the input image; DETR expects ImageNet-normalized tensors
transform = transforms.Compose([
    transforms.Resize(800),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Define the object detection function (operates on a single H x W x 3 RGB frame)
def detect_objects(frame):
    # Work on a writable copy of the input frame
    frame = frame.copy()
    height, width = frame.shape[:2]

    # Convert the frame to a PIL image, apply the transformation, add a batch dimension
    image = Image.fromarray(frame)
    tensor = transform(image).unsqueeze(0).to(device)

    # Perform object detection
    with torch.no_grad():
        outputs = model(tensor)

    # DETR returns class logits and normalized (cx, cy, w, h) boxes;
    # keep predictions whose best non-background score exceeds 0.7
    probas = outputs["pred_logits"].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.7
    boxes = outputs["pred_boxes"][0, keep].cpu().numpy()
    labels = probas[keep].argmax(-1).cpu().numpy()

    # Draw bounding boxes on the frame, converting the normalized
    # center-size boxes to pixel corner coordinates
    for box, label in zip(boxes, labels):
        cx, cy, bw, bh = box
        x0, y0 = int((cx - bw / 2) * width), int((cy - bh / 2) * height)
        x1, y1 = int((cx + bw / 2) * width), int((cy + bh / 2) * height)
        frame = cv2.rectangle(frame, (x0, y0), (x1, y1), (0, 255, 0), 2)
        frame = cv2.putText(frame, f"Class: {label}", (x0, y0 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2, cv2.LINE_AA)
    return frame
# Define the Gradio interface. detect_objects works on a single frame, so the app
# takes and returns an image; the original video input passed a file path into the
# function, and capture_session has been removed from recent Gradio releases.
iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(),
    outputs=gr.Image(),
)

# Launch the Gradio app
iface.launch()
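
If downloading the torch hub model is not an option in the Space, a rough alternative is to load DETR through the Hugging Face transformers library instead. The sketch below assumes the facebook/detr-resnet-50 checkpoint and that the transformers and timm packages are installed; it only covers model loading and post-processing (the helper name detect_objects_hf is just for illustration), not the drawing or Gradio parts above.

# Minimal sketch (assumes transformers + timm are installed) of running DETR
# through the Hugging Face transformers API instead of torch.hub.
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
hf_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").eval()

def detect_objects_hf(frame):
    # frame is an H x W x 3 RGB numpy array, as provided by gr.Image()
    image = Image.fromarray(frame)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = hf_model(**inputs)
    # Convert raw logits/boxes to pixel-space detections above the threshold
    target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.7
    )[0]
    return [
        (hf_model.config.id2label[label.item()], score.item(), box.tolist())
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"])
    ]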