# Source: Hugging Face Space by ianpan — commit 231edce ("Initial commit"), 2.32 kB.
import cv2
import glob
import gradio as gr
import mediapy
import nibabel
import numpy as np
import shutil
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from skp import builder
def window(x, WL=400, WW=2500):
    """Apply an intensity window to ``x`` and return it as 8-bit grayscale.

    Values are clipped to ``[WL - WW // 2, WL + WW // 2]``, shifted so the
    lower bound becomes 0, divided by the window span, and multiplied by 255
    before truncating to uint8.

    Args:
        x: Array of raw intensities.
        WL: Window level (centre of the window).
        WW: Window width (integer-halved on each side of ``WL``).

    Returns:
        uint8 array with the same shape as ``x``.
    """
    half_width = WW // 2
    lo = WL - half_width
    hi = WL + half_width
    scaled = (np.clip(x, lo, hi) - lo) / (hi - lo)
    return (scaled * 255).astype("uint8")
def rescale(x):
    """Map values in the 0..255 range linearly onto -1..1.

    Works on anything supporting scalar arithmetic (numbers, numpy arrays,
    torch tensors): divide by 255, subtract 0.5, then double.
    """
    return (x / 255. - 0.5) * 2.0
def _predict_labels(img):
    """Run ``seg_model`` on a windowed volume and return a label volume.

    Args:
        img: 3-D uint8 volume as produced by ``window``.

    Returns:
        Numpy array with the same shape as ``img``: 0 where the summed
        probability of the first 7 output channels is below 0.5, otherwise
        1..7 = (index of the most probable of those channels) + 1.
    """
    # Add batch and channel dims, resample to a 192^3 grid, scale to [-1, 1].
    X = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)
    X = F.interpolate(X, size=(192, 192, 192), mode="nearest")
    X = rescale(X)

    with torch.no_grad():
        probs = torch.sigmoid(seg_model(X))

    # Labels are drawn from the same 7 channels that gate the foreground, so
    # they always lie in 1..7 and the `* 255 / 7` uint8 scaling in
    # `_render_frames` cannot overflow even if the model emits extra channels.
    spine_probs = probs[:, :7]
    labels = torch.argmax(spine_probs, dim=1) + 1
    labels[spine_probs.sum(1) < 0.5] = 0

    # Nearest-neighbour resampling back to the input grid keeps labels integral.
    labels = F.interpolate(labels.unsqueeze(0).float(), size=img.shape, mode="nearest")
    return labels.squeeze(0).squeeze(0).numpy()


def _render_frames(img, labels, skip=8):
    """Build side-by-side RGB frames: grayscale slice | JET-colored labels.

    Slices are taken along axis 2 of ``img`` / ``labels``, every ``skip``-th one.
    """
    # Spread labels 0..7 over 0..255 so JET gives each class a distinct color.
    label_u8 = (labels * 255 / 7).astype("uint8")
    frames = []
    for idx in range(0, img.shape[2], skip):
        gray = cv2.cvtColor(np.ascontiguousarray(img[:, :, idx]), cv2.COLOR_GRAY2RGB)
        # applyColorMap is a per-pixel lookup that returns BGR; convert to RGB so
        # both halves (and mediapy, which expects RGB) agree on channel order.
        color = cv2.applyColorMap(np.ascontiguousarray(label_u8[:, :, idx]), cv2.COLORMAP_JET)
        color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
        frames.append(np.concatenate((gray, color), axis=1))
    return frames


def generate_segmentation_video(study):
    """Segment a study and write a side-by-side preview video.

    Args:
        study: Path to a NIfTI volume (the UI passes one of the
            ``examples/*.nii.gz`` files).

    Returns:
        The path ``"video.mp4"`` (in the working directory, overwritten on
        every call), written at 30 fps with every 8th slice along axis 2.
    """
    # Flip axes 1 and 2, then reverse the axis order of the loaded volume.
    img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
    img = window(img)

    labels = _predict_labels(img)
    frames = _render_frames(img, labels, skip=8)

    mediapy.write_video("video.mp4", frames, fps=30)
    return "video.mp4"
# Point mediapy at the ffmpeg found on PATH. shutil.which returns None when it
# is missing; in that case leave mediapy's own default in place instead of
# configuring it with None.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)

# Build the segmentation model from its YAML config and switch it to eval mode.
# NOTE(review): `load_pretrained = "seg.ckpt"` presumably makes the builder
# load weights from that checkpoint — confirm against skp.builder.
config = OmegaConf.load("configs/pseudoseg000.yaml")
config.model.load_pretrained = "seg.ckpt"
seg_model = builder.build_model(config).eval()

# Candidate studies offered in the UI dropdown.
examples = glob.glob("examples/*.nii.gz")
# Gradio UI: pick an example study, press Predict, and view the rendered video.
with gr.Blocks(theme="dark-peach") as demo:
    select_study = gr.Dropdown(
        label="Select a study",
        choices=sorted(examples),
        type="value",
    )
    button_predict = gr.Button("Predict")
    video_output = gr.Video()

    # Wire the button: dropdown value -> generate_segmentation_video -> video.
    button_predict.click(
        fn=generate_segmentation_video,
        inputs=select_study,
        outputs=video_output,
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)