import cv2
import glob
import gradio as gr
import mediapy
import nibabel
import numpy as np
import shutil
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from skp import builder
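
# Intensity windowing: clip the CT volume to the given window (WL = window
# level, WW = window width) and rescale to uint8 in [0, 255].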
def window(x, WL=400, WW=2500):
    lower, upper = WL - WW // 2, WL + WW // 2
    x = np.clip(x, lower, upper)
    x = x - lower
    x = x / (upper - lower)
    return (x * 255).astype("uint8")
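
# Map uint8 intensities from [0, 255] to [-1, 1] before passing them to the models.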
def rescale(x):
    x = x / 255.
    x = x - 0.5
    x = x * 2.0
    return x
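
# Derive a bounding box (z, h, w extents) for each of the 7 vertebral-level
# channels by thresholding the predicted probabilities at 0.4 and rescaling
# the coordinates from the model's output grid to the original volume shape.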
def get_cervical_spine_coordinates(x, original_shape):
    # Assumes x is torch tensor with shape (1, 8, Z, H, W)
    x = x.squeeze(0).numpy()[:7]
    rescale_factor = [original_shape[0] / x.shape[1], original_shape[1] / x.shape[2], original_shape[2] / x.shape[3]]
    coords_dict = {}
    for level in range(x.shape[0]):
        coords = np.where(x[level] >= 0.4)
        coords = np.vstack(coords).astype("float")
        coords[0] = coords[0] * rescale_factor[0]
        coords[1] = coords[1] * rescale_factor[1]
        coords[2] = coords[2] * rescale_factor[2]
        coords = coords.astype("int")
        coords_dict[level] = coords[0].min(), coords[0].max(), \
            coords[1].min(), coords[1].max(), \
            coords[2].min(), coords[2].max()
    return coords_dict
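
# Full inference pipeline for a single study: window the volume, run the
# segmentation model, crop each vertebral level, extract per-level features
# with the X3D chunk model, predict fracture probabilities with the sequence
# model, and render a side-by-side video of the image and its segmentation.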
def generate_segmentation_video(study):
    img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
    img = window(img)
    X = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)
    X = F.interpolate(X, size=(192, 192, 192), mode="nearest")
    X = rescale(X)
    with torch.no_grad():
        seg_output = seg_model(X)
        seg_output = torch.sigmoid(seg_output)
    c_spine_coords = get_cervical_spine_coordinates(seg_output, img.shape)
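    # Crop each vertebral level from the original-resolution volume, resize it
    # to the chunk model's input size, and extract features.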
    chunk_features = []
    for level, coords in c_spine_coords.items():
        z1, z2, h1, h2, w1, w2 = coords
        X = torch.from_numpy(img[z1:z2 + 1, h1:h2 + 1, w1:w2 + 1]).float().unsqueeze(0).unsqueeze(0)
        X = F.interpolate(X, size=(64, 288, 288), mode="nearest")
        X = rescale(X)
        with torch.no_grad():
            chunk_features.append(x3d_model.extract_features(X))
    chunk_features = torch.stack(chunk_features, dim=1)
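    # Aggregate the per-level features with the sequence model into fracture
    # probabilities for C1-C7 plus an overall score.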
    with torch.no_grad():
        final_output = torch.sigmoid(seq_model((chunk_features, torch.ones((chunk_features.size(1),)))))
    final_output_dict = {f"C{i+1}": final_output[:, i].item() for i in range(7)}
    final_output_dict["Overall"] = final_output[:, -1].item()
    seg_output = F.interpolate(seg_output, size=img.shape, mode="nearest").squeeze(0).numpy()
    # shape = (8, Z, H, W)
    p_spine = seg_output[:7].sum(0)
    # shape = (Z, H, W)
    seg_output = np.argmax(seg_output[:7], axis=0) + 1
    # shape = (Z, H, W)
    seg_output[p_spine < 0.5] = 0
    seg_output = (seg_output * 255 / 7).astype("uint8")
    seg_output = np.stack([cv2.applyColorMap(s, cv2.COLORMAP_JET) for s in seg_output])
    seg_output[p_spine < 0.5] = 0
    frames = []
    skip = 8
    for idx in range(0, img.shape[2], skip):
        i = img[:, :, idx]
        o = seg_output[:, :, idx]
        i = cv2.cvtColor(i, cv2.COLOR_GRAY2RGB)
        frame = np.concatenate((i, o), 1)
        frames.append(frame)
    mediapy.write_video("video.mp4", frames, fps=10)
    return "video.mp4", final_output_dict
ffmpeg_path = shutil.which('ffmpeg')
mediapy.set_ffmpeg(ffmpeg_path)
config = OmegaConf.load("configs/pseudoseg000.yaml")
config.model.load_pretrained = "seg.ckpt"
config.model.params.encoder_params.pretrained = False
seg_model = builder.build_model(config).eval()
config = OmegaConf.load("configs/chunk000.yaml")
config.model.load_pretrained = "x3d.ckpt"
config.model.params.pretrained = False
x3d_model = builder.build_model(config).eval()
config = OmegaConf.load("configs/chunkseq003.yaml")
config.model.load_pretrained = "seq.ckpt"
seq_model = builder.build_model(config).eval()
examples = glob.glob("examples/*.nii.gz")
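
# Gradio UI: select an example study, click Predict, and view the segmentation
# video alongside the per-level fracture probabilities.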
with gr.Blocks(theme="dark-peach") as demo:
    select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
    button_predict = gr.Button("Predict")
    video_output = gr.Video(label="Cervical Spine Segmentation")
    label_output = gr.Label(label="Fracture Predictions", show_label=False)
    button_predict.click(fn=generate_segmentation_video,
                         inputs=select_study,
                         outputs=[video_output, label_output])

if __name__ == "__main__":
    demo.launch(debug=True)