# Scraped file-viewer header (not part of the program):
# "Last commit not found" | raw | history | blame | 1.66 kB
import cv2
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import gradio as gr
import numpy as np
# Function for DETR object detection
def inf(_, webcam_image):
# Initialize model and processor
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
yellow = (0, 255, 255) # in BGR
stroke = 2
# Convert the webcam image to the correct format
img = cv2.cvtColor(webcam_image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(img)
# Process the image with DETR
inputs = processor(images=pil_image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([pil_image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
# Draw bounding boxes and labels
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
cv2.rectangle(webcam_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), yellow, stroke)
cv2.putText(webcam_image, model.config.id2label[label.item()], (int(box[0]), int(box[1]-10)), cv2.FONT_HERSHEY_SIMPLEX, 1, yellow, stroke, cv2.LINE_AA)
# Return the processed image
return webcam_image
# Live Gradio app: a heading and a streaming webcam feed are the inputs, and
# the annotated frame produced by `inf` is shown as the output image.
demo = gr.Interface(
    fn=inf,
    inputs=[
        gr.Markdown("## Real-Time Object Detection"),
        gr.Image(source="webcam", streaming=True),
    ],
    outputs="image",
    live=True,
)

# Bind on every network interface and request a public share link.
demo.launch(share=True, server_name="0.0.0.0")