masdar's picture
upload app.py
f6a04da
raw
history blame
3.6 kB
import os
from os.path import splitext
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torchvision
import wget
# Folder names used by the app, relative to the current working directory.
destination_folder = "output"
destination_for_weights = "weights"

# Make sure the weights folder exists before anything is downloaded into it.
if not os.path.exists(destination_for_weights):
    print("Creating folder at ", destination_for_weights, " to store weights")
    os.mkdir(destination_for_weights)
else:
    print("The weights are at", destination_for_weights)

# Release asset holding the segmentation checkpoint.
segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'

# Local path the checkpoint is expected at: <weights folder>/<file name from the URL>.
_weights_file = os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))

# Download only when the checkpoint file is not already on disk.
if os.path.exists(_weights_file):
    print("Segmentation Weights already present")
else:
    print("Downloading Segmentation Weights, ", segmentationWeightsURL, " to ", _weights_file)
    filename = wget.download(segmentationWeightsURL, out=destination_for_weights)

torch.cuda.empty_cache()
def collate_fn(x):
    """Merge a batch of ``(array, label)`` pairs into one tensor.

    Args:
        x: Iterable of 2-tuples. The first element of each pair is an array
            with at least two axes; the second is passed through untouched.
            (Presumably axis 1 is the frame axis of a video clip - confirm
            against the dataset that feeds this function.)

    Returns:
        A 3-tuple ``(batch, labels, lengths)`` where ``batch`` is the arrays
        concatenated along axis 1 with axes 0 and 1 swapped afterwards,
        ``labels`` is a tuple of the second elements, and ``lengths`` is a
        list with each array's size along axis 1.
    """
    arrays, labels = zip(*x)
    lengths = [arr.shape[1] for arr in arrays]
    # Join along axis 1, then bring that joined axis to the front.
    joined = np.concatenate(arrays, axis=1)
    batch = torch.as_tensor(joined.swapaxes(0, 1))
    return batch, labels, lengths
# Build a DeepLabV3 / ResNet-50 segmentation network and swap its last
# classifier layer for a convolution with a single output channel.
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)

# Actual checkpoint file on disk: <weights folder>/<file name from the URL>.
checkpoint_path = os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))
print("loading weights from ", checkpoint_path)

# NOTE(review): torch.load unpickles the file, and this checkpoint comes from
# a network download. Consider `weights_only=True` if the installed torch
# version supports it.
if torch.cuda.is_available():
    print("cuda is available, original weights")
    device = torch.device("cuda")
    # The model is wrapped in DataParallel *before* loading, so the
    # checkpoint keys are used exactly as stored.
    model = torch.nn.DataParallel(model)
    model.to(device)
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
else:
    print("cuda is not available, cpu weights")
    device = torch.device("cpu")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # The bare (unwrapped) model is used on CPU, so the "module." prefix that
    # DataParallel puts on parameter names has to be removed. Strip it only
    # when it is really there instead of blindly dropping 7 characters.
    _prefix = "module."
    state_dict_cpu = {
        (k[len(_prefix):] if k.startswith(_prefix) else k): v
        for (k, v) in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict_cpu)

# Inference only: put the network in evaluation mode.
model.eval()
def segment(input):
    """Run the segmentation model on one image and paint the result onto it.

    Args:
        input: NumPy image array laid out as (height, width, 3). The name
            shadows the builtin but is kept so existing callers still work.

    Returns:
        A copy of ``input`` in which every pixel whose model output is
        greater than 0 has been overwritten with ``[0, 0, 255]``.
    """
    image = input

    # (H, W, C) -> (1, C, H, W), the layout the network is fed with.
    batch = np.expand_dims(image.transpose([2, 0, 1]), axis=0)

    # Standardise each channel with this image's own mean and std.
    mean = batch.mean(axis=(0, 2, 3))
    std = batch.std(axis=(0, 2, 3))
    batch = batch - mean.reshape(1, 3, 1, 1)
    batch = batch / std.reshape(1, 3, 1, 1)

    with torch.no_grad():
        tensor = torch.from_numpy(batch).float().to(device)
        output = model(tensor)

    # Bring the result back to host memory first: calling .numpy() directly
    # on a CUDA tensor raises, which broke every request when a GPU was used.
    logits = output['out'].cpu().numpy().squeeze()

    # Positive output marks the segmented region.
    foreground = logits > 0
    mask = image.copy()
    mask[foreground] = np.array([0, 0, 255])
    return mask
import gradio as gr
# Input and output components of the web UI.
i = gr.inputs.Image(shape=(112, 112), label="Echocardiogram")
o = gr.outputs.Image(label="Segmentation Mask")

# One example row per image file; each row holds a single input value.
examples = [
    [image_name]
    for image_name in (
        "TCGA_CS_5395_19981004_12.png",
        "TCGA_CS_5395_19981004_14.png",
        "TCGA_DU_5849_19950405_24.png",
    )
]

# NOTE(review): title, description and example file names talk about brain
# MRI, while the input label above says "Echocardiogram" - confirm which
# use case this app is meant to present.
title = "Sistem Segmentasi Citra MRI Otak berbasis Artificial Intelligence"
description = "This system is designed to help automate the process of accurately and efficiently segmenting brain MRIs into regions of interest. It does this by using a UBNet-Seg Architecture that has been trained on a large dataset of manually annotated brain images."
#thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"

# Wire `segment` to the components and start the app.
gr.Interface(
    fn=segment,
    inputs=i,
    outputs=o,
    allow_flagging=False,
    description=description,
    title=title,
    examples=examples,
    analytics_enabled=False,
).launch()