File size: 3,024 Bytes
fb993be
 
 
 
 
 
 
 
 
 
 
 
 
 
b721d57
fb993be
b721d57
 
fb993be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b721d57
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import functools
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchio as tio
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

# Disable autograd for the importing thread; this app only runs inference.
# NOTE: grad mode is thread-local, so code running in other threads must
# disable it itself (e.g. with ``torch.no_grad()``).
torch.set_grad_enabled(False)

# Download the example image used by the Gradio examples list.
# ``import urllib`` alone does not guarantee that ``urllib.request`` is loaded,
# so import the submodule explicitly.
import urllib.request

url, filename = (
    "https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png",
    "TCGA_CS_4944.png",
)
# ``urllib.URLopener`` does not exist on Python 3 (the legacy, deprecated class
# lives in ``urllib.request``), so call ``urlretrieve`` directly. Skip the
# download when a previous run already fetched the file.
if not Path(filename).exists():
    urllib.request.urlretrieve(url, filename)
def inference(img):

    path = img
    slices = [tio.ScalarImage(path).data]
    tensor = torch.cat(slices, dim=-1)
    guessed_affine = np.diag([-1, -1, 9, 1])
    subject = tio.Subject(mri=tio.ScalarImage(tensor=tensor, affine=guessed_affine))
    subject_preprocessed = tio.ZNormalization()(subject)
    subject_preprocessed.plot()
    subject_preprocessed.mri
    patch_overlap = 0
    patch_size = 256, 256, 1
    grid_sampler = tio.inference.GridSampler(
        subject_preprocessed,
        patch_size,
        patch_overlap,
    )
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=8)
    aggregator = tio.inference.GridAggregator(grid_sampler)
    model = torch.hub.load(
        'mateuszbuda/brain-segmentation-pytorch',
        'unet',
        in_channels=3,
        out_channels=1,
        init_features=32,
        pretrained=True,
    )
    for patches_batch in tqdm(patch_loader):
        input_tensor = patches_batch['mri'][tio.DATA][..., 0]
        locations = patches_batch[tio.LOCATION]
        probs = model(input_tensor)[..., np.newaxis]
        aggregator.add_batch(probs, locations)
    output_tensor = aggregator.get_output_tensor()
    output_subject = tio.Subject(prediction=tio.ScalarImage(tensor=output_tensor, affine=guessed_affine))
    images = subject_preprocessed.mri.tensor.detach().numpy().reshape((3, 256, 256))
    mask = output_subject.prediction.tensor.detach().numpy().reshape((256, 256))
    images = np.moveaxis(np.moveaxis(images, 0, 2), 0, 1)
    mask = np.moveaxis(mask, 0, 1)

    f, ax = plt.subplots(1, 2)
    ax[0].set_axis_off()
    ax[1].set_axis_off()
    ax[0].imshow(images)
    ax[1].imshow(mask, cmap='gray')
    return f

# Static text shown around the demo.
title = "U-NET FOR BRAIN MRI"
description = "Gradio demo for u-net for brain mri, U-Net with batch normalization for biomedical image segmentation with pretrained weights for abnormality segmentation in brain MRI. To use it, simply add your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://mateuszbuda.github.io/2017/12/01/brainseg.html'>Segmentation of brain tumor in magnetic resonance images</a> | <a href='https://github.com/mateuszbuda/brain-segmentation-pytorch'>Github Repo</a></p>"
# One example row per clickable sample; each row holds the single input value.
examples = [['TCGA_CS_4944.png']]

# Wire the segmentation handler into an image-upload -> plot interface and serve it.
demo = gr.Interface(
    fn=inference,
    inputs=gr.Image(label="input image", type='filepath'),
    outputs=gr.Plot(),
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
demo.launch()