Spaces:
Sleeping
Sleeping
Last commit not found
import cv2 | |
import torch | |
from transformers import DetrImageProcessor, DetrForObjectDetection | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
# Hugging Face checkpoint used for detection.
_MODEL_NAME = "facebook/detr-resnet-50"
# Minimum confidence for a detection to be drawn.
_SCORE_THRESHOLD = 0.9
# Annotation colour and line thickness. gr.Image hands the callback RGB
# arrays, so this tuple is yellow in RGB channel order.
_YELLOW = (255, 255, 0)
_STROKE = 2
# Lazily-populated (processor, model) pair shared by every call to inf().
_detr = None


def _load_detr():
    """Return the DETR (processor, model) pair, loading it on first use only.

    A streaming webcam calls inf() once per frame; calling from_pretrained
    there would reload the weights every frame. Caching keeps the per-frame
    cost down to inference alone.
    """
    global _detr
    if _detr is None:
        processor = DetrImageProcessor.from_pretrained(_MODEL_NAME)
        model = DetrForObjectDetection.from_pretrained(_MODEL_NAME)
        _detr = (processor, model)
    return _detr


# Function for DETR object detection
def inf(_, webcam_image):
    """Run DETR on one webcam frame and return it annotated with detections.

    Args:
        _: Value of the Markdown input component; ignored.
        webcam_image: Frame from gr.Image as a uint8 numpy array in RGB
            order, or None when the webcam has not produced a frame yet.

    Returns:
        A copy of the frame with a box and class label drawn for every
        detection scoring above _SCORE_THRESHOLD, or None if no frame was
        given.
    """
    if webcam_image is None:
        return None
    processor, model = _load_detr()

    # The frame is already RGB, which is what PIL and DETR expect; running
    # COLOR_BGR2RGB on it would swap the red and blue channels.
    pil_image = Image.fromarray(webcam_image)
    inputs = processor(images=pil_image, return_tensors="pt")
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # Post-processing expects (height, width); PIL's size is (width, height).
    target_sizes = torch.tensor([pil_image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=_SCORE_THRESHOLD
    )[0]

    # Draw on a copy so the caller's array is left untouched.
    annotated = webcam_image.copy()
    for label, box in zip(results["labels"], results["boxes"]):
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        cv2.rectangle(annotated, (x0, y0), (x1, y1), _YELLOW, _STROKE)
        # Keep the label on-screen when the box touches the top edge.
        text_y = max(y0 - 10, 20)
        cv2.putText(
            annotated,
            model.config.id2label[label.item()],
            (x0, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            _YELLOW,
            _STROKE,
            cv2.LINE_AA,
        )
    return annotated
# Gradio interface with webcam support.
# The Markdown component is listed as the first input, so inf() receives its
# value as the ignored `_` argument and the webcam frame as the second one.
demo = gr.Interface(
    inf,
    [
        gr.Markdown("## Real-Time Object Detection"),
        gr.Image(source="webcam", streaming=True),
    ],
    "image",
    live=True,
)

# Start the server only when run as a script, not as a side effect of
# importing this module (e.g. by tooling that just wants the `demo` object).
if __name__ == "__main__":
    # NOTE(review): binds to all interfaces AND requests a public share link,
    # exposing the webcam app beyond localhost — confirm this is intended.
    demo.launch(server_name="0.0.0.0", share=True)