|
import gradio as gr |
|
from PIL import Image |
|
import numpy as np |
|
import torch |
|
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights |
|
from torchvision.utils import draw_bounding_boxes |
|
import matplotlib.pyplot as plt |
|
from io import BytesIO |
|
|
|
# Default pretrained weights for the Faster R-CNN ResNet-50 FPN v2 detector.
# The weights object also supplies the class-name table and the input
# transform used by the rest of this module.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
categories = weights.meta["categories"]  # indexed by predicted label id
img_preprocess = weights.transforms()  # applied to each image before inference
|
|
|
def load_model():
    """Build the detector from the module-level ``weights``.

    The network is constructed with ``box_score_thresh=0.5`` and switched to
    evaluation mode before being returned.
    """
    detector = fasterrcnn_resnet50_fpn_v2(
        weights=weights,
        box_score_thresh=0.5,
    )
    # nn.Module.eval() returns the module itself, so this both switches the
    # mode and hands back the same object.
    return detector.eval()
|
|
|
# Single shared detector instance, created once at import time and reused by
# every prediction request.
model = load_model()
|
|
|
def make_prediction(img):
    """Run the detector on a single image and return its prediction dict.

    Args:
        img: An image accepted by ``img_preprocess`` (``process_image``
            passes a PIL image).

    Returns:
        The first element of the model output for the one-image batch, with
        its ``"labels"`` entry replaced by a list of category-name strings
        looked up in ``categories``. All other entries are left as the
        model produced them.
    """
    img_processed = img_preprocess(img)

    # Inference only: disabling gradient tracking stops the forward pass from
    # recording an autograd graph, which otherwise costs memory on every
    # request and yields output tensors that require grad.
    with torch.no_grad():
        prediction = model(img_processed.unsqueeze(0))[0]

    # Convert label ids to plain ints before indexing the category table.
    prediction["labels"] = [
        categories[label_id] for label_id in prediction["labels"].tolist()
    ]
    return prediction
|
|
|
def create_image_with_bboxes(img, prediction):
    """Draw the predicted boxes onto ``img`` and return the result as NumPy.

    Args:
        img: Image data convertible with ``torch.tensor``; the caller in this
            file passes the image transposed to channel-first order.
        prediction: Mapping providing ``"boxes"`` and ``"labels"``.

    Returns:
        The annotated image as a NumPy array with axes reordered by
        ``transpose(1, 2, 0)`` (channel axis moved last).
    """
    names = prediction["labels"]
    # "person" boxes are highlighted in red; every other class is green.
    box_colors = ["green" if name != "person" else "red" for name in names]

    annotated = draw_bounding_boxes(
        torch.tensor(img),
        boxes=prediction["boxes"],
        labels=names,
        colors=box_colors,
        width=2,
    )
    return annotated.detach().numpy().transpose(1, 2, 0)
|
|
|
def process_image(image):
    """Detect objects in an uploaded image and build the three UI outputs.

    Args:
        image: NumPy image array supplied by the Gradio input component, or
            ``None`` when the user submitted without uploading an image.

    Returns:
        A tuple ``(rendered, detected_objects, prediction_data)``:
        ``rendered`` is a PIL image of the figure showing the bounding boxes,
        ``detected_objects`` is a list of ``"label: score"`` strings, and
        ``prediction_data`` is the prediction dict with tensors converted to
        plain lists.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before running detection.")

    # Let Pillow infer the mode from the array shape and then normalise to
    # RGB, so grayscale or RGBA arrays are converted rather than misread.
    img = Image.fromarray(image.astype("uint8")).convert("RGB")
    prediction = make_prediction(img)
    img_with_bbox = create_image_with_bboxes(
        np.array(img).transpose(2, 0, 1), prediction
    )

    # Draw through the explicit Figure/Axes objects rather than pyplot's
    # implicit "current figure", so concurrent requests cannot draw onto each
    # other's figures. The figure is always closed, even if rendering fails.
    fig = plt.figure(figsize=(12, 12))
    img_bytes = BytesIO()
    try:
        ax = fig.add_subplot(111)
        ax.imshow(img_with_bbox)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines[["top", "bottom", "right", "left"]].set_visible(False)
        fig.tight_layout()
        fig.savefig(img_bytes, format="png")
    finally:
        plt.close(fig)
    img_bytes.seek(0)

    detected_objects = [
        f"{label}: {score:.2f}"
        for label, score in zip(prediction["labels"], prediction["scores"])
    ]

    # Tensors are not JSON-serialisable; convert them to nested lists.
    prediction_data = {
        key: (value.tolist() if isinstance(value, torch.Tensor) else value)
        for key, value in prediction.items()
    }
    return Image.open(img_bytes), detected_objects, prediction_data
|
|
|
# Web UI: one image input; outputs are the annotated image, a text summary and
# the raw prediction as JSON.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="pil"), gr.Textbox(), gr.JSON()],
    title="OBJECT_DETECTOR_254",
    description="Upload an image to detect objects and display bounding boxes along with a summary of detected objects.",
)
demo.launch()
|
|