File size: 2,319 Bytes
231edce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import glob
import shutil
import tempfile

import cv2
import gradio as gr
import mediapy
import nibabel
import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

from skp import builder


def window(x, WL=400, WW=2500):
    """Apply an intensity window and quantise the result to 8 bits.

    Values are clipped to ``[WL - WW // 2, WL + WW // 2]``, shifted so the
    lower bound maps to 0, divided by the window width and scaled to 0-255.

    Args:
        x: Numeric array of raw intensities.
        WL: Window level (centre of the window).
        WW: Window width.

    Returns:
        ``uint8`` array with the same shape as ``x``.
    """
    half_width = WW // 2
    low = WL - half_width
    high = WL + half_width
    normalized = (np.clip(x, low, high) - low) / (high - low)
    return (normalized * 255).astype("uint8")


def rescale(x):
    """Map values from the 0-255 range onto -1..1.

    Works on anything supporting scalar arithmetic (numpy arrays, torch
    tensors, plain floats): ``(x / 255 - 0.5) * 2``.
    """
    return (x / 255. - 0.5) * 2.0


def _segment_volume(img):
    """Run ``seg_model`` on a windowed volume and return a label map at ``img``'s size.

    The volume is resampled to 192x192x192 (nearest neighbour), rescaled to
    [-1, 1] and passed through ``seg_model``. Labels are then resampled back
    to ``img.shape`` with nearest-neighbour interpolation.

    Args:
        img: 3-D ``uint8`` array as produced by ``window``.

    Returns:
        Float numpy array of shape ``img.shape``. A voxel is 0 where the summed
        sigmoid probability of the first 7 output channels is below 0.5;
        otherwise it is the 1-based index of the highest-scoring channel.
    """
    x = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)  # add batch and channel dims
    x = F.interpolate(x, size=(192, 192, 192), mode="nearest")
    x = rescale(x)
    with torch.no_grad():
        probs = torch.sigmoid(seg_model(x))

    # NOTE(review): argmax considers every output channel while the
    # foreground test only sums the first 7. If the model emits more than 7
    # channels, labels above 7 overflow the 0-255 scaling in
    # _colorize_labels — confirm the model's channel count.
    p_spine = probs[:, :7].sum(1)
    labels = torch.argmax(probs, dim=1) + 1
    labels[p_spine < 0.5] = 0

    labels = F.interpolate(labels.unsqueeze(0).float(), size=img.shape, mode="nearest")
    return labels.squeeze(0).squeeze(0).numpy()


def _colorize_labels(labels):
    """Colour a label volume with the JET colormap; returns an (..., 3) RGB uint8 array.

    Labels are scaled by 255 / 7 before colour mapping. Each slice along
    axis 0 is mapped separately.
    """
    scaled = (labels * 255 / 7).astype("uint8")
    # cv2.applyColorMap produces BGR; mediapy.write_video treats 3-channel
    # frames as RGB, so convert to keep the colormap's intended colours.
    return np.stack([
        cv2.cvtColor(cv2.applyColorMap(slice_, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        for slice_ in scaled
    ])


def generate_segmentation_video(study):
    """Segment a NIfTI study and render it as a side-by-side MP4 video.

    Args:
        study: Path to a NIfTI file (the UI offers files matching
            ``examples/*.nii.gz``).

    Returns:
        Path to the written MP4. Each frame takes slice ``[:, :, idx]`` for
        every 8th ``idx`` along the last axis. The windowed intensities are
        on the left and the coloured label map is on the right, at 30 fps.

    Raises:
        ValueError: If no study was selected.
    """
    if not study:
        raise ValueError("No study selected.")

    # Reverse the last two voxel axes, then reverse the axis order.
    img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
    img = window(img)
    colored = _colorize_labels(_segment_volume(img))

    skip = 8
    frames = []
    for idx in range(0, img.shape[2], skip):
        gray = cv2.cvtColor(img[:, :, idx], cv2.COLOR_GRAY2RGB)
        frames.append(np.concatenate((gray, colored[:, :, idx]), 1))

    # Use a unique output file so concurrent requests don't overwrite each
    # other's video (previously every call wrote ./video.mp4).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    mediapy.write_video(out_path, frames, fps=30)
    return out_path


# shutil.which returns None when ffmpeg is not on PATH. Only override
# mediapy's ffmpeg setting when a real path was found, rather than handing
# it None.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)

# Model setup runs at import time. The config's load_pretrained field is
# pointed at seg.ckpt (presumably consumed by skp.builder to load weights —
# confirm). The built model is used by generate_segmentation_video.
config = OmegaConf.load("configs/pseudoseg000.yaml")
config.model.load_pretrained = "seg.ckpt"
seg_model = builder.build_model(config).eval()
examples = glob.glob("examples/*.nii.gz")

# UI: pick an example study, click Predict, and show the rendered video.
with gr.Blocks(theme="dark-peach") as demo:
    select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
    button_predict = gr.Button("Predict")
    video_output = gr.Video()
    button_predict.click(fn=generate_segmentation_video,
                         inputs=select_study,
                         outputs=video_output)


if __name__ == "__main__":
    # share=True asks Gradio for a public link in addition to the local server.
    demo.launch(debug=True, share=True)