MWilinski's picture
feat: working demo
755362b
raw
history blame
1.43 kB
from pathlib import Path

import gradio as gr
import cv2
from PIL import Image
import numpy as np
from gradio import components
import torchvision
from torchvision.models.detection import (
    maskrcnn_resnet50_fpn,
    MaskRCNN_ResNet50_FPN_Weights,
)
import torchvision.transforms.functional as F
import torch
from torchvision.utils import draw_segmentation_masks
# Default pretrained Mask R-CNN weights, plus the preprocessing preset that
# belongs to them (applied to every input before it reaches the model).
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

# Build the network once at import time and put it in evaluation mode;
# nn.Module.eval() returns the module itself, so the call can be chained.
model = maskrcnn_resnet50_fpn(weights=weights, progress=False).eval()
def segment_and_show(image, proba_threshold=0.5, score_threshold=0.5):
    """Segment object instances in ``image`` and return it with the masks drawn on.

    Args:
        image: Input picture as a NumPy array (the default format of the
            Gradio ``Image`` input; presumably H x W x C uint8 -- confirm for
            other callers). ``None`` (no image supplied) is returned as
            ``None`` instead of raising.
        proba_threshold: Per-pixel mask probability above which a pixel is
            treated as belonging to an instance.
        score_threshold: Detection score an instance must exceed to be drawn;
            lower-confidence detections are discarded.

    Returns:
        The image with the kept instance masks blended in (``alpha=0.7``),
        as an H x W x 3 uint8 NumPy array, or ``None`` if ``image`` is
        ``None``.
    """
    if image is None:
        return None

    # Normalise to 3-channel RGB: draw_segmentation_masks expects a (3, H, W)
    # image, so grayscale or RGBA arrays would otherwise fail.
    pil_image = Image.fromarray(image).convert("RGB")
    # uint8 (3, H, W) tensor that the masks are drawn onto.
    canvas = F.pil_to_tensor(pil_image)

    # Pure inference: disable autograd so no graph / activations are retained.
    with torch.no_grad():
        output = model([transforms(pil_image)])[0]

    # Keep only confident detections, then binarise their soft masks.
    # squeeze(1) drops the singleton dimension -> (num_kept, H, W).
    keep = output["scores"] > score_threshold
    masks = output["masks"][keep].squeeze(1) > proba_threshold

    image_with_segmasks = draw_segmentation_masks(canvas, masks, alpha=0.7)
    return np.array(F.to_pil_image(image_with_segmasks))
# Resolve the sample image next to this script rather than against the current
# working directory, so the demo starts regardless of where it is launched from.
default_image = Image.open(Path(__file__).with_name("demo.jpeg"))

# Single-image-in / single-image-out UI around segment_and_show; the input is
# pre-filled with the sample picture and accepts uploads or clipboard pastes.
iface = gr.Interface(
    fn=segment_and_show,
    inputs=components.Image(value=default_image, sources=["upload", "clipboard"]),
    outputs=components.Image(type="pil"),
    title="Urban Autonomy Instance Segmentation Demo",
    description="Upload an image or use the default to see the instance segmentation model in action.",
)

if __name__ == "__main__":
    iface.launch()