# SaladSlayer00's picture
# the APP
# d7feb62
# raw
# history blame
# 1.41 kB
import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import torch
import cv2
import numpy as np
def process_image(input_image):
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
yellow = (0, 255, 255) # BGR
font = cv2.FONT_HERSHEY_SIMPLEX
stroke = 2
# Convert PIL image to OpenCV format
img = np.array(input_image)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# Process the image
inputs = processor(images=input_image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([input_image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), yellow, stroke)
cv2.putText(img, model.config.id2label[label.item()], (int(box[0]), int(box[1]-10)), font, 1, yellow, stroke, cv2.LINE_AA)
# Convert back to PIL image
return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# Create Gradio interface
# process_image works with the PIL API, so ask Gradio for PIL images explicitly.
# The top-level gr.Image component is used instead of the deprecated
# gr.inputs.Image namespace.
iface = gr.Interface(fn=process_image, inputs=gr.Image(type="pil"), outputs="image")
iface.launch()