Spaces:
Runtime error
Runtime error
Add fracture inference
Browse files
app.py
CHANGED
@@ -27,6 +27,24 @@ def rescale(x):
|
|
27 |
return x
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def generate_segmentation_video(study):
|
31 |
img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
|
32 |
img = window(img)
|
@@ -38,13 +56,34 @@ def generate_segmentation_video(study):
|
|
38 |
seg_output = seg_model(X)
|
39 |
|
40 |
seg_output = torch.sigmoid(seg_output)
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
seg_output[p_spine < 0.5] = 0
|
44 |
-
seg_output = F.interpolate(seg_output.unsqueeze(0).float(), size=img.shape, mode="nearest")
|
45 |
-
seg_output = seg_output.squeeze(0).squeeze(0).numpy()
|
46 |
seg_output = (seg_output * 255 / 7).astype("uint8")
|
47 |
seg_output = np.stack([cv2.applyColorMap(_, cv2.COLORMAP_JET) for _ in seg_output])
|
|
|
48 |
|
49 |
frames = []
|
50 |
skip = 8
|
@@ -54,8 +93,8 @@ def generate_segmentation_video(study):
|
|
54 |
i = cv2.cvtColor(i, cv2.COLOR_GRAY2RGB)
|
55 |
frame = np.concatenate((i, o), 1)
|
56 |
frames.append(frame)
|
57 |
-
mediapy.write_video("video.mp4", frames, fps=
|
58 |
-
return "video.mp4"
|
59 |
|
60 |
|
61 |
ffmpeg_path = shutil.which('ffmpeg')
|
@@ -65,15 +104,26 @@ config = OmegaConf.load("configs/pseudoseg000.yaml")
|
|
65 |
config.model.load_pretrained = "seg.ckpt"
|
66 |
config.model.params.encoder_params.pretrained = False
|
67 |
seg_model = builder.build_model(config).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
examples = glob.glob("examples/*.nii.gz")
|
69 |
|
70 |
with gr.Blocks(theme="dark-peach") as demo:
|
71 |
select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
|
72 |
button_predict = gr.Button("Predict")
|
73 |
-
video_output = gr.Video()
|
|
|
74 |
button_predict.click(fn=generate_segmentation_video,
|
75 |
inputs=select_study,
|
76 |
-
outputs=video_output)
|
77 |
|
78 |
|
79 |
if __name__ == "__main__":
|
|
|
27 |
return x
|
28 |
|
29 |
|
30 |
+
def get_cervical_spine_coordinates(x, original_shape):
|
31 |
+
# Assumes x is torch tensor with shape (1, 8, Z, H, W)
|
32 |
+
x = x.squeeze(0).numpy()[:7]
|
33 |
+
rescale_factor = [original_shape[0] / x.shape[1], original_shape[1] / x.shape[2], original_shape[2] / x.shape[3]]
|
34 |
+
coords_dict = {}
|
35 |
+
for level in range(x.shape[0]):
|
36 |
+
coords = np.where(x[level] >= 0.4)
|
37 |
+
coords = np.vstack(coords).astype("float")
|
38 |
+
coords[0] = coords[0] * rescale_factor[0]
|
39 |
+
coords[1] = coords[1] * rescale_factor[1]
|
40 |
+
coords[2] = coords[2] * rescale_factor[2]
|
41 |
+
coords = coords.astype("int")
|
42 |
+
coords_dict[level] = coords[0].min(), coords[0].max(),\
|
43 |
+
coords[1].min(), coords[1].max(),\
|
44 |
+
coords[2].min(), coords[2].max()
|
45 |
+
return coords_dict
|
46 |
+
|
47 |
+
|
48 |
def generate_segmentation_video(study):
|
49 |
img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
|
50 |
img = window(img)
|
|
|
56 |
seg_output = seg_model(X)
|
57 |
|
58 |
seg_output = torch.sigmoid(seg_output)
|
59 |
+
c_spine_coords = get_cervical_spine_coordinates(seg_output, img.shape)
|
60 |
+
|
61 |
+
chunk_features = []
|
62 |
+
for level, coords in c_spine_coords.items():
|
63 |
+
z1, z2, h1, h2, w1, w2 = coords
|
64 |
+
X = torch.from_numpy(img[z1:z2+1, h1:h2+1, w1:w2+1]).float().unsqueeze(0).unsqueeze(0)
|
65 |
+
X = F.interpolate(X, size=(64, 288, 288), mode="nearest")
|
66 |
+
X = rescale(X)
|
67 |
+
with torch.no_grad():
|
68 |
+
chunk_features.append(x3d_model.extract_features(X))
|
69 |
+
|
70 |
+
chunk_features = torch.stack(chunk_features, dim=1)
|
71 |
+
with torch.no_grad():
|
72 |
+
final_output = torch.sigmoid(seq_model((chunk_features, torch.ones((chunk_features.size(1), )))))
|
73 |
+
|
74 |
+
final_output_dict = {f"C{i+1}": final_output[:, i].item() for i in range(7)}
|
75 |
+
final_output_dict["Overall"] = final_output[:, -1].item()
|
76 |
+
|
77 |
+
seg_output = F.interpolate(seg_output, size=img.shape, mode="nearest").squeeze(0).numpy()
|
78 |
+
# shape = (8, Z, H, W)
|
79 |
+
p_spine = seg_output[:7].sum(0)
|
80 |
+
# shape = (Z, H, W)
|
81 |
+
seg_output = np.argmax(seg_output, dim=0) + 1
|
82 |
+
# shape = (Z, H, W)
|
83 |
seg_output[p_spine < 0.5] = 0
|
|
|
|
|
84 |
seg_output = (seg_output * 255 / 7).astype("uint8")
|
85 |
seg_output = np.stack([cv2.applyColorMap(_, cv2.COLORMAP_JET) for _ in seg_output])
|
86 |
+
seg_output[p_spine < 0.5] = 0
|
87 |
|
88 |
frames = []
|
89 |
skip = 8
|
|
|
93 |
i = cv2.cvtColor(i, cv2.COLOR_GRAY2RGB)
|
94 |
frame = np.concatenate((i, o), 1)
|
95 |
frames.append(frame)
|
96 |
+
mediapy.write_video("video.mp4", frames, fps=10)
|
97 |
+
return "video.mp4", final_output_dict
|
98 |
|
99 |
|
100 |
ffmpeg_path = shutil.which('ffmpeg')
|
|
|
104 |
config.model.load_pretrained = "seg.ckpt"
|
105 |
config.model.params.encoder_params.pretrained = False
|
106 |
seg_model = builder.build_model(config).eval()
|
107 |
+
|
108 |
+
config = OmegaConf.load("configs/chunk000.yaml")
|
109 |
+
config.model.load_pretrained = "x3d.ckpt"
|
110 |
+
config.model.params.pretrained = False
|
111 |
+
x3d_model = builder.build_model(config).eval()
|
112 |
+
|
113 |
+
config = OmegaConf.load("configs/chunkseq003.yaml")
|
114 |
+
config.model.load_pretrained = "seq.ckpt"
|
115 |
+
seq_model = builder.build_model(config).eval()
|
116 |
+
|
117 |
examples = glob.glob("examples/*.nii.gz")
|
118 |
|
119 |
with gr.Blocks(theme="dark-peach") as demo:
|
120 |
select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
|
121 |
button_predict = gr.Button("Predict")
|
122 |
+
video_output = gr.Video(label="Cervical Spine Segmentation")
|
123 |
+
label_output = gr.Label(label="Fracture Predictions", show_label=False)
|
124 |
button_predict.click(fn=generate_segmentation_video,
|
125 |
inputs=select_study,
|
126 |
+
outputs=[video_output, label_output])
|
127 |
|
128 |
|
129 |
if __name__ == "__main__":
|