import gradio as gr
from PIL import Image
import numpy as np
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt
from io import BytesIO
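
# Pretrained Faster R-CNN weights, their COCO category names, and the matching preprocessing transform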
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
categories = weights.meta["categories"]
img_preprocess = weights.transforms()
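

# Build the pretrained detector (ResNet-50 FPN v2 backbone); keep only detections scoring >= 0.5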
def load_model():
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.5)
    model.eval()
    return model

model = load_model()
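

# Run the detector on a PIL image and map numeric class ids to COCO category names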
def make_prediction(img):
    img_processed = img_preprocess(img)
    with torch.no_grad():  # inference only; gradients are not needed
        prediction = model(img_processed.unsqueeze(0))[0]
    prediction["labels"] = [categories[label] for label in prediction["labels"]]
    return prediction

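

# Draw predicted boxes onto the image tensor: "person" boxes in red, everything else in green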
def create_image_with_bboxes(img, prediction):
    img_tensor = torch.tensor(img)  # uint8 tensor in (C, H, W) layout, as draw_bounding_boxes expects
    img_with_bboxes = draw_bounding_boxes(
        img_tensor,
        boxes=prediction["boxes"],
        labels=prediction["labels"],
        colors=["red" if label == "person" else "green" for label in prediction["labels"]],
        width=2,
    )
    img_with_bboxes_np = img_with_bboxes.detach().numpy().transpose(1, 2, 0)  # back to (H, W, C) for plotting
    return img_with_bboxes_np

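

# Gradio callback: run detection, render the annotated image, and summarise the results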
def process_image(image):
    img = Image.fromarray(image.astype("uint8"), "RGB")
    prediction = make_prediction(img)
    img_with_bbox = create_image_with_bboxes(np.array(img).transpose(2, 0, 1), prediction)

    # Render the annotated image without axes, ticks, or frame
    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(111)
    plt.imshow(img_with_bbox)
    plt.xticks([], [])
    plt.yticks([], [])
    ax.spines[["top", "bottom", "right", "left"]].set_visible(False)
    plt.tight_layout()

    # Save the plot to a BytesIO object, then close the figure to free memory
    img_bytes = BytesIO()
    fig.savefig(img_bytes, format="png")
    img_bytes.seek(0)
    plt.close(fig)

    # Create a human-readable summary of detected objects
    detected_objects = []
    for label, score in zip(prediction["labels"], prediction["scores"]):
        detected_objects.append(f"{label}: {score:.2f}")

    # Convert tensors to plain lists so the prediction can be displayed as JSON
    prediction_data = {k: (v.tolist() if isinstance(v, torch.Tensor) else v) for k, v in prediction.items()}
    return Image.open(img_bytes), "\n".join(detected_objects), prediction_data

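

# Wire everything into a Gradio interface: numpy image in; annotated image, text summary, and raw JSON out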
gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="pil"), gr.Textbox(), gr.JSON()],
    title="OBJECT_DETECTOR_254",
    description="Upload an image to detect objects and display bounding boxes along with a summary of detected objects.",
).launch()