import os
import numpy as np
import codecs
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
from unetplusplus import NestedUNet
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
# Device
DEVICE = "cpu"
print(DEVICE)
# Load color map
cmap = np.load('cmap.npy')
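# Note (assumption): cmap.npy is expected to hold one RGB color per class
# (roughly shape (num_classes, 3)); indexing it with the predicted label map
# in inference() below colorizes the segmentation output.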
# Make directories
os.system("mkdir ./models")
# Get model weights
if not os.path.exists("./models/masksupnyu39.31d.pth"):
    os.system("wget -O ./models/masksupnyu39.31d.pth https://github.com/hasibzunair/masksup-segmentation/releases/download/v0.1/masksupnyu39.31iou.pth")
# Load model
model = NestedUNet(num_classes=40)
checkpoint = torch.load("./models/masksupnyu39.31d.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
model = model.to(DEVICE)
model.eval()
# Main inference function
def inference(img_path):
    image = Image.open(img_path).convert("RGB")
    transforms_image = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    image = transforms_image(image)
    image = image[None, :]  # add batch dimension
    # Predict
    with torch.no_grad():
        output = torch.sigmoid(model(image.to(DEVICE).float()))
    # Per-pixel class labels: argmax over the class dimension
    output = torch.softmax(output, dim=1).argmax(dim=1)[0].float().cpu().numpy().astype(np.uint8)
    # Map class indices to RGB colors via the loaded color map
    pred = cmap[output]
    return pred
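# --- Hypothetical local sanity check (not part of the original app) ---
# Assuming a sample image exists under ./sample_images/, inference() can be
# exercised without launching the Gradio app; the returned array is the
# colorized label map, so its dtype/range depend on how cmap.npy was saved:
#
#   mask = inference("./sample_images/a.png")
#   Image.fromarray(np.asarray(mask, dtype=np.uint8)).save("pred_a.png")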
# App
title = "Masked Supervised Learning for Semantic Segmentation"
description = codecs.open("description.html", "r", "utf-8").read()
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>"
gr.Interface(
    inference,
    gr.inputs.Image(type='filepath', label="Input Image"),
    gr.outputs.Image(type="filepath", label="Predicted Output"),
    examples=["./sample_images/a.png", "./sample_images/b.png",
              "./sample_images/c.png", "./sample_images/d.png"],
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    analytics_enabled=False,
).launch(debug=True, enable_queue=True)