"""Gradio demo: brain-tumor segmentation of MRI slices with a pretrained U-Net.

The U-Net is loaded from the ``mateuszbuda/brain-segmentation-pytorch``
torch.hub repository; the app shows the uploaded slice next to the model's
predicted mask.
"""
import functools
import urllib.request
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchio as tio
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

# Disables autograd for the importing thread only (grad mode is thread-local),
# which is why inference() also disables it around the forward pass.
torch.set_grad_enabled(False)

# Download an example image for the Gradio examples gallery.
url, filename = (
    "https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png",
    "TCGA_CS_4944.png",
)
urllib.request.urlretrieve(url, filename)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the pretrained U-Net from torch.hub once and reuse it across calls.

    Returns:
        The hub model, switched to evaluation mode.
    """
    model = torch.hub.load(
        'mateuszbuda/brain-segmentation-pytorch',
        'unet',
        in_channels=3,
        out_channels=1,
        init_features=32,
        pretrained=True,
    )
    # eval() gives normalization/dropout layers (if any) their inference
    # behaviour: predictions do not depend on the current batch's statistics,
    # and forward passes do not alter the state of this shared, cached model.
    model.eval()
    return model


def inference(img):
    """Segment one uploaded MRI slice and plot it next to the predicted mask.

    Args:
        img: File object from the Gradio ``type='file'`` image input; only its
            ``.name`` (path of the uploaded file on disk) is read.

    Returns:
        A ``matplotlib.figure.Figure`` with two panels: the input image and the
        model output rendered in grayscale.
    """
    # A pyplot-free Figure is not registered with pyplot's global figure
    # manager, so figures from earlier requests can be garbage-collected.
    from matplotlib.figure import Figure

    # Assumes the upload loads as a channels-first 3-channel volume with a
    # singleton last (depth) axis: the model takes 3 input channels and the
    # patches below have depth 1 — TODO confirm for non-RGB uploads.
    tensor = tio.ScalarImage(img.name).data
    # No spatial metadata is relied on; a guessed diagonal affine is used.
    guessed_affine = np.diag([-1, -1, 9, 1])
    subject = tio.Subject(mri=tio.ScalarImage(tensor=tensor, affine=guessed_affine))
    subject_preprocessed = tio.ZNormalization()(subject)

    patch_overlap = 0
    patch_size = 256, 256, 1
    grid_sampler = tio.inference.GridSampler(
        subject_preprocessed,
        patch_size,
        patch_overlap,
    )
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=8)
    aggregator = tio.inference.GridAggregator(grid_sampler)
    model = _load_model()

    # Grad mode is thread-local and Gradio may run this function in a worker
    # thread, where the module-level set_grad_enabled(False) does not apply.
    with torch.no_grad():
        for patches_batch in patch_loader:
            # Drop the singleton depth axis: the network is 2D.
            input_tensor = patches_batch['mri'][tio.DATA][..., 0]
            locations = patches_batch[tio.LOCATION]
            # Re-add the depth axis so the aggregator can place each patch.
            probs = model(input_tensor)[..., np.newaxis]
            aggregator.add_batch(probs, locations)
        output_tensor = aggregator.get_output_tensor()

    # Channels-first (C, X, Y) -> (Y, X, C) for imshow. Display the raw pixels
    # min-max rescaled to [0, 1]: the z-normalized tensor contains negative
    # values that imshow would clip. astype() copies, so `tensor` is untouched.
    image = np.transpose(tensor[..., 0].numpy(), (2, 1, 0)).astype(np.float32)
    image -= image.min()
    peak = image.max()
    if peak > 0:
        image /= peak
    # Single output channel, depth index 0; transpose to match the image axes.
    mask = output_tensor[0, :, :, 0].numpy().T

    fig = Figure()
    ax_image, ax_mask = fig.subplots(1, 2)
    ax_image.set_axis_off()
    ax_mask.set_axis_off()
    ax_image.imshow(image)
    ax_mask.imshow(mask, cmap='gray')
    return fig


title = "U-NET FOR BRAIN MRI"
description = "Segmentation of brain tumor in magnetic resonance images"
# NOTE(review): reconstructed from garbled source text (this name was used but
# never defined); the link target is the torch.hub repository used above —
# confirm the intended wording.
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/mateuszbuda/brain-segmentation-pytorch' "
    "target='_blank'>Github Repo</a></p>"
)
examples = [['TCGA_CS_4944.png']]

gr.Interface(
    inference,
    gr.inputs.Image(label="input image", type='file'),
    gr.outputs.Image(type='plot'),
    description=description,
    article=article,
    title=title,
    examples=examples,
    analytics_enabled=False,
).launch()