import cv2
import glob
import gradio as gr
import mediapy
import nibabel
import numpy as np
import shutil
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from skp import builder
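# Apply a CT intensity window (center WL, width WW) and scale to uint8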
def window(x, WL=400, WW=2500):
    lower, upper = WL - WW // 2, WL + WW // 2
    x = np.clip(x, lower, upper)
    x = x - lower
    x = x / (upper - lower)
    return (x * 255).astype("uint8")
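# Map uint8 intensities from [0, 255] to [-1, 1] for model input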
def rescale(x):
    x = x / 255.
    x = x - 0.5
    x = x * 2.0
    return x
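# Derive a bounding box (z1, z2, h1, h2, w1, w2) in the original volume for
# each of the 7 cervical levels from the segmentation probabilities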
def get_cervical_spine_coordinates(x, original_shape):
    # Assumes x is a torch tensor with shape (1, 8, Z, H, W)
    x = x.squeeze(0).numpy()[:7]
    rescale_factor = [original_shape[0] / x.shape[1], original_shape[1] / x.shape[2], original_shape[2] / x.shape[3]]
    coords_dict = {}
    for level in range(x.shape[0]):
        coords = np.where(x[level] >= 0.4)
        coords = np.vstack(coords).astype("float")
        coords[0] = coords[0] * rescale_factor[0]
        coords[1] = coords[1] * rescale_factor[1]
        coords[2] = coords[2] * rescale_factor[2]
        coords = coords.astype("int")
        coords_dict[level] = (coords[0].min(), coords[0].max(),
                              coords[1].min(), coords[1].max(),
                              coords[2].min(), coords[2].max())
    return coords_dict
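# End-to-end inference for one study: segment the cervical spine on a
# downsampled volume, crop each vertebra, predict fracture probabilities,
# and render a video of the segmentation next to the original slices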
def generate_segmentation_video(study):
    img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
    img = window(img)
    X = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)
    X = F.interpolate(X, size=(192, 192, 192), mode="nearest")
    X = rescale(X)
    with torch.no_grad():
        seg_output = seg_model(X)
        seg_output = torch.sigmoid(seg_output)
    c_spine_coords = get_cervical_spine_coordinates(seg_output, img.shape)
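    # Crop each predicted vertebra from the full-resolution volume and
    # extract features for it with the X3D chunk model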
    chunk_features = []
    for level, coords in c_spine_coords.items():
        z1, z2, h1, h2, w1, w2 = coords
        X = torch.from_numpy(img[z1:z2+1, h1:h2+1, w1:w2+1]).float().unsqueeze(0).unsqueeze(0)
        X = F.interpolate(X, size=(64, 288, 288), mode="nearest")
        X = rescale(X)
        with torch.no_grad():
            chunk_features.append(x3d_model.extract_features(X))
    chunk_features = torch.stack(chunk_features, dim=1)
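    # Aggregate the per-vertebra features with the sequence model to get
    # fracture probabilities for C1-C7 plus an overall prediction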
    with torch.no_grad():
        final_output = torch.sigmoid(seq_model((chunk_features, torch.ones((chunk_features.size(1),)))))
    final_output_dict = {f"C{i+1}": final_output[:, i].item() for i in range(7)}
    final_output_dict["Overall"] = final_output[:, -1].item()
    seg_output = F.interpolate(seg_output, size=img.shape, mode="nearest").squeeze(0).numpy()
    # shape = (8, Z, H, W)
    p_spine = seg_output[:7].sum(0)
    # shape = (Z, H, W)
    seg_output = np.argmax(seg_output, axis=0) + 1
    # shape = (Z, H, W)
    seg_output[p_spine < 0.5] = 0
    seg_output = (seg_output * 255 / 7).astype("uint8")
    seg_output = np.stack([cv2.applyColorMap(_, cv2.COLORMAP_JET) for _ in seg_output])
    seg_output[p_spine < 0.5] = 0
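    # Render every `skip`-th slice along the last axis next to its overlay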
    frames = []
    skip = 8
    for idx in range(0, img.shape[2], skip):
        i = img[:, :, idx]
        o = seg_output[:, :, idx]
        i = cv2.cvtColor(i, cv2.COLOR_GRAY2RGB)
        frame = np.concatenate((i, o), 1)
        frames.append(frame)
    mediapy.write_video("video.mp4", frames, fps=10)
    return "video.mp4", final_output_dict
ffmpeg_path = shutil.which('ffmpeg')
mediapy.set_ffmpeg(ffmpeg_path)
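# Build the three models from their configs and load the trained weights:
# a 3D segmentation model, an X3D feature extractor for vertebra crops,
# and a sequence model that aggregates the per-vertebra features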
config = OmegaConf.load("configs/pseudoseg000.yaml")
config.model.load_pretrained = "seg.ckpt"
config.model.params.encoder_params.pretrained = False
seg_model = builder.build_model(config).eval()
config = OmegaConf.load("configs/chunk000.yaml")
config.model.load_pretrained = "x3d.ckpt"
config.model.params.pretrained = False
x3d_model = builder.build_model(config).eval()
config = OmegaConf.load("configs/chunkseq003.yaml")
config.model.load_pretrained = "seq.ckpt"
seq_model = builder.build_model(config).eval()
examples = glob.glob("examples/*.nii.gz")
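# Gradio interface: choose an example study, then run inference to produce
# the segmentation video and the fracture probabilities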
with gr.Blocks(theme="dark-peach") as demo:
    select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
    button_predict = gr.Button("Predict")
    video_output = gr.Video(label="Cervical Spine Segmentation")
    label_output = gr.Label(label="Fracture Predictions", show_label=False)
    button_predict.click(fn=generate_segmentation_video,
                         inputs=select_study,
                         outputs=[video_output, label_output])
if __name__ == "__main__":
    demo.launch(debug=True)