File size: 3,781 Bytes
84bb4ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6a04da
 
84bb4ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b29d3e5
 
9d3129c
 
84bb4ea
8133a9d
 
2d4378c
 
 
8133a9d
449c97f
 
84bb4ea
451eff0
 
b29d3e5
8133a9d
 
 
451eff0
 
8133a9d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from os.path import splitext
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torchvision
import wget 


# Local directories used by this app. `destination_folder` is not referenced by
# the code visible in this file.
destination_folder = "output"
destination_for_weights = "weights"

# Make sure the weights directory exists (it is created on the first run).
if not os.path.exists(destination_for_weights):
    print("Creating folder at ", destination_for_weights, " to store weights")
    os.mkdir(destination_for_weights)
else:
    print("The weights are at", destination_for_weights)

# Checkpoint file `deeplabv3_resnet50_random.pt` from the EchoNet-Dynamic v1.0.0 GitHub release.
segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'

# Download the checkpoint only when no file with the URL's basename is cached locally.
_weights_file = os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))
if os.path.exists(_weights_file):
    print("Segmentation Weights already present")
else:
    print("Downloading Segmentation Weights, ", segmentationWeightsURL, " to ", _weights_file)
    filename = wget.download(segmentationWeightsURL, out=destination_for_weights)

# Release cached CUDA allocator memory before the model is built.
torch.cuda.empty_cache()

def collate_fn(x):
    """Merge a batch of ``(array, tag)`` pairs into a single tensor.

    Args:
        x: iterable of 2-tuples. The first item of each pair is an array with
            at least two dimensions; all arrays must agree on every axis except
            axis 1, along which they are concatenated. The second item is passed
            through untouched.

    Returns:
        A tuple ``(batch, tags, lengths)``:
        - ``batch``: ``torch`` tensor of the arrays concatenated along axis 1,
          with axes 0 and 1 then swapped.
        - ``tags``: tuple of the second items, in input order.
        - ``lengths``: list of each array's size along axis 1, so callers can
          split ``batch`` back into its original pieces.
    """
    arrays, tags = zip(*x)
    lengths = [arr.shape[1] for arr in arrays]
    joined = np.concatenate(arrays, axis=1)
    batch = torch.as_tensor(joined.swapaxes(0, 1))
    return batch, tags, lengths

# Build DeepLabV3-ResNet50 without any pretrained weights. The full checkpoint is
# loaded (strictly) below, so downloading ImageNet backbone weights would be wasted.
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, pretrained_backbone=False, aux_loss=False)
# Replace the last classifier conv with a 1-output-channel conv so the head
# produces a single logit map.
model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)

# The checkpoint was saved under the basename of the download URL.
weights_path = os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))
print("loading weights from ", weights_path)

if torch.cuda.is_available():
    print("cuda is available, original weights")
    device = torch.device("cuda")
    # Wrapping in DataParallel gives parameter names a "module." prefix; the
    # checkpoint's state_dict is expected to use that same prefixed naming.
    model = torch.nn.DataParallel(model)
    model.to(device)
    checkpoint = torch.load(weights_path)
    model.load_state_dict(checkpoint['state_dict'])
else:
    print("cuda is not available, cpu weights")
    device = torch.device("cpu")
    checkpoint = torch.load(weights_path, map_location="cpu")
    # The bare (unwrapped) model has no "module." prefix: strip it, but only from
    # keys that actually carry it so unprefixed keys are not truncated.
    _prefix = "module."
    state_dict_cpu = {
        (k[len(_prefix):] if k.startswith(_prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict_cpu)

# Inference mode (affects batch-norm and dropout layers).
model.eval()

def segment(input):
    """Paint the model's predicted foreground onto a copy of an RGB image.

    Args:
        input: image array laid out as (height, width, channel) with exactly 3
            channels; the HWC->CHW transpose and ``reshape(1, 3, 1, 1)`` below
            require this. Presumably the uint8 array delivered by the Gradio
            Image component — TODO confirm for the pinned Gradio version.

    Returns:
        A copy of ``input`` where every pixel whose output logit is > 0 is set
        to ``(0, 0, 255)``; all other pixels are unchanged.

    Uses the module-level ``model`` and ``device``.
    """
    inp = np.asarray(input)
    # HWC -> CHW, then add a batch axis of size 1 -> NCHW.
    x = np.expand_dims(inp.transpose([2, 0, 1]), axis=0)

    # Per-channel standardisation using statistics of this image only.
    mean = x.mean(axis=(0, 2, 3))
    std = x.std(axis=(0, 2, 3))
    # A constant channel has std 0; dividing by it would yield NaN/inf, so such
    # channels are only mean-centred.
    std = np.where(std == 0, 1.0, std)
    x = (x - mean.reshape(1, 3, 1, 1)) / std.reshape(1, 3, 1, 1)

    with torch.no_grad():
        batch = torch.from_numpy(x).float().to(device)
        output = model(batch)

    # Bring the logits back to host memory first: Tensor.numpy() raises for
    # tensors that live on a CUDA device.
    logits = output['out'].cpu().numpy()
    # Take batch 0 / channel 0 explicitly; squeeze() would also drop a height
    # or width of 1 and break the boolean indexing below.
    foreground = logits[0, 0] > 0

    mask = inp.copy()
    mask[foreground] = np.array([0, 0, 255])

    return mask

import gradio as gr

# UI components. NOTE(review): in the legacy gr.inputs API, `shape` makes Gradio
# crop/resize uploads to this size before `segment` receives them — confirm for
# the pinned Gradio version.
i = gr.inputs.Image(shape=(112, 112), label="Input Brain MRI")
o = gr.outputs.Image(label="Hasil Segmentasi")

# Example files shown in the UI; Gradio takes one inner list (of inputs) per example.
examples = [
    [example_png]
    for example_png in (
        "TCGA_CS_5395_19981004_12.png",
        "TCGA_CS_5395_19981004_14.png",
        "TCGA_DU_5849_19950405_20.png",
        "TCGA_DU_5849_19950405_24.png",
        "TCGA_DU_5849_19950405_28.png",
    )
]

# Page text displayed around the demo.
# NOTE(review): the description mentions a "UBNet-Seg Architecture" trained on
# brain images, while this file builds torchvision's deeplabv3_resnet50 and loads
# the EchoNet-Dynamic checkpoint deeplabv3_resnet50_random.pt — confirm the copy.
title = "Sistem Segmentasi Citra MRI Otak berbasis Artificial Intelligence"
description = "This system is designed to help automate the process of accurately and efficiently segmenting brain MRIs into regions of interest. It does this by using a UBNet-Seg Architecture that has been trained on a large dataset of manually annotated brain images."

article = "<p style='text-align: center'>Created by <a target='_blank' href='https://fi.ub.ac.id/'>Jurusan Fisika, FMIPA, Universitas Brawijaya </a></p>"

# Build the interface around `segment` and start serving it.
gr.Interface(
    segment,
    i,
    o,
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging=False,
    analytics_enabled=False,
).launch()