import gradio as gr
from PIL import Image
import numpy as np
import torch
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn_v2,
    FasterRCNN_ResNet50_FPN_V2_Weights,
)
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt
from io import BytesIO

# Pretrained COCO weights, their category names, and the matching preprocessing transform.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
categories = weights.meta["categories"]
img_preprocess = weights.transforms()


def load_model():
    """Load Faster R-CNN in eval mode, keeping only detections with score >= 0.5."""
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.5)
    model.eval()
    return model


model = load_model()


def make_prediction(img):
    """Run inference on a PIL image and map numeric labels to category names."""
    img_processed = img_preprocess(img)
    with torch.no_grad():
        prediction = model(img_processed.unsqueeze(0))[0]
    prediction["labels"] = [categories[label] for label in prediction["labels"]]
    return prediction


def create_image_with_bboxes(img, prediction):
    """Draw bounding boxes on a CHW uint8 array; people in red, everything else in green."""
    img_tensor = torch.tensor(img)
    img_with_bboxes = draw_bounding_boxes(
        img_tensor,
        boxes=prediction["boxes"],
        labels=prediction["labels"],
        colors=["red" if label == "person" else "green" for label in prediction["labels"]],
        width=2,
    )
    # Convert back to HWC for display.
    return img_with_bboxes.detach().numpy().transpose(1, 2, 0)


def process_image(image):
    img = Image.fromarray(image.astype("uint8"), "RGB")
    prediction = make_prediction(img)
    img_with_bbox = create_image_with_bboxes(np.array(img).transpose(2, 0, 1), prediction)

    # Render the annotated image without ticks or frame.
    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(111)
    plt.imshow(img_with_bbox)
    plt.xticks([], [])
    plt.yticks([], [])
    ax.spines[["top", "bottom", "right", "left"]].set_visible(False)
    plt.tight_layout()

    # Save the plot to a BytesIO buffer, then release the figure.
    img_bytes = BytesIO()
    fig.savefig(img_bytes, format="png")
    plt.close(fig)
    img_bytes.seek(0)

    # Build a text summary of detected objects with confidence scores.
    detected_objects = "\n".join(
        f"{label}: {score:.2f}"
        for label, score in zip(prediction["labels"], prediction["scores"])
    )

    # JSON-serializable copy of the raw prediction (tensors converted to lists).
    prediction_data = {
        k: (v.tolist() if isinstance(v, torch.Tensor) else v) for k, v in prediction.items()
    }

    return Image.open(img_bytes), detected_objects, prediction_data


gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="pil"), gr.Textbox(), gr.JSON()],
    title="OBJECT_DETECTOR_254",
    description="Upload an image to detect objects and display bounding boxes along with a summary of detected objects.",
).launch()