ianpan committed
Commit b0fbed3
1 Parent(s): e7cf218

Add fracture inference

Files changed (1):
  1. app.py +58 -8

app.py CHANGED
@@ -27,6 +27,24 @@ def rescale(x):
     return x
 
 
+def get_cervical_spine_coordinates(x, original_shape):
+    # Assumes x is torch tensor with shape (1, 8, Z, H, W)
+    x = x.squeeze(0).numpy()[:7]
+    rescale_factor = [original_shape[0] / x.shape[1], original_shape[1] / x.shape[2], original_shape[2] / x.shape[3]]
+    coords_dict = {}
+    for level in range(x.shape[0]):
+        coords = np.where(x[level] >= 0.4)
+        coords = np.vstack(coords).astype("float")
+        coords[0] = coords[0] * rescale_factor[0]
+        coords[1] = coords[1] * rescale_factor[1]
+        coords[2] = coords[2] * rescale_factor[2]
+        coords = coords.astype("int")
+        coords_dict[level] = coords[0].min(), coords[0].max(),\
+                             coords[1].min(), coords[1].max(),\
+                             coords[2].min(), coords[2].max()
+    return coords_dict
+
+
 def generate_segmentation_video(study):
     img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
     img = window(img)
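Note: the new helper thresholds each of the seven sigmoid vertebra channels at 0.4 and rescales the surviving voxel indices from the low-resolution segmentation grid back to the original volume, yielding one bounding box per level C1-C7. A minimal self-contained sketch of that logic; the tensor sizes, the fake "C3" blob, and the full-resolution shape are all illustrative assumptions, not values from the app:

import numpy as np
import torch

# Dummy sigmoid output with shape (1, 8, Z, H, W); channel 2 gets a fake "C3" blob.
probs = torch.zeros(1, 8, 32, 64, 64)
probs[0, 2, 10:14, 20:30, 25:35] = 0.9

x = probs.squeeze(0).numpy()[:7]
original_shape = (128, 256, 256)  # hypothetical full-resolution (Z, H, W)
scale = [original_shape[i] / x.shape[i + 1] for i in range(3)]

coords = np.vstack(np.where(x[2] >= 0.4)).astype("float")
coords *= np.array(scale)[:, None]       # map grid indices back to full resolution
coords = coords.astype("int")
print(coords[0].min(), coords[0].max())  # z-extent of the "C3" bounding box: 40 52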
@@ -38,13 +56,34 @@ def generate_segmentation_video(study):
     seg_output = seg_model(X)
 
     seg_output = torch.sigmoid(seg_output)
-    p_spine = seg_output[:, :7].sum(1)
-    seg_output = torch.argmax(seg_output, dim=1) + 1
+    c_spine_coords = get_cervical_spine_coordinates(seg_output, img.shape)
+
+    chunk_features = []
+    for level, coords in c_spine_coords.items():
+        z1, z2, h1, h2, w1, w2 = coords
+        X = torch.from_numpy(img[z1:z2+1, h1:h2+1, w1:w2+1]).float().unsqueeze(0).unsqueeze(0)
+        X = F.interpolate(X, size=(64, 288, 288), mode="nearest")
+        X = rescale(X)
+        with torch.no_grad():
+            chunk_features.append(x3d_model.extract_features(X))
+
+    chunk_features = torch.stack(chunk_features, dim=1)
+    with torch.no_grad():
+        final_output = torch.sigmoid(seq_model((chunk_features, torch.ones((chunk_features.size(1), )))))
+
+    final_output_dict = {f"C{i+1}": final_output[:, i].item() for i in range(7)}
+    final_output_dict["Overall"] = final_output[:, -1].item()
+
+    seg_output = F.interpolate(seg_output, size=img.shape, mode="nearest").squeeze(0).numpy()
+    # shape = (8, Z, H, W)
+    p_spine = seg_output[:7].sum(0)
+    # shape = (Z, H, W)
+    seg_output = np.argmax(seg_output, axis=0) + 1
+    # shape = (Z, H, W)
     seg_output[p_spine < 0.5] = 0
-    seg_output = F.interpolate(seg_output.unsqueeze(0).float(), size=img.shape, mode="nearest")
-    seg_output = seg_output.squeeze(0).squeeze(0).numpy()
     seg_output = (seg_output * 255 / 7).astype("uint8")
     seg_output = np.stack([cv2.applyColorMap(_, cv2.COLORMAP_JET) for _ in seg_output])
+    seg_output[p_spine < 0.5] = 0
 
     frames = []
     skip = 8
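Note: each vertebra crop is resized to a fixed (64, 288, 288) chunk, embedded by the X3D backbone, and the seven embeddings are stacked into a (1, 7, D) sequence for the sequence classifier. A shape-level sketch of that flow; extract_features here is a stand-in for x3d_model.extract_features (the real one comes from the repo's builder), and the embedding width D is an assumption:

import torch
import torch.nn.functional as F

D = 256  # assumed feature width; the real value is set by configs/chunk000.yaml

def extract_features(x):                 # stand-in for x3d_model.extract_features
    return torch.randn(x.size(0), D)

chunk_features = []
for _ in range(7):                       # one crop per vertebra, C1 through C7
    crop = torch.randn(1, 1, 50, 200, 200)           # raw crop, arbitrary size
    crop = F.interpolate(crop, size=(64, 288, 288), mode="nearest")
    with torch.no_grad():
        chunk_features.append(extract_features(crop))

chunk_features = torch.stack(chunk_features, dim=1)
print(chunk_features.shape)              # torch.Size([1, 7, 256])
# seq_model then consumes (features, mask) and emits 8 scores: C1-C7 + Overall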
@@ -54,8 +93,8 @@ def generate_segmentation_video(study):
         i = cv2.cvtColor(i, cv2.COLOR_GRAY2RGB)
         frame = np.concatenate((i, o), 1)
         frames.append(frame)
-    mediapy.write_video("video.mp4", frames, fps=30)
-    return "video.mp4"
+    mediapy.write_video("video.mp4", frames, fps=10)
+    return "video.mp4", final_output_dict
 
 
 ffmpeg_path = shutil.which('ffmpeg')
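Note: mediapy.write_video accepts an iterable of H×W×3 uint8 frames, so dropping fps from 30 to 10 just slows playback of the same frames. A tiny sketch with synthetic frames (the ramp pattern is made up):

import numpy as np
import mediapy

# Ten synthetic brightness-ramp frames -> roughly a one-second clip at fps=10.
frames = [np.full((64, 128, 3), i * 25, dtype=np.uint8) for i in range(10)]
mediapy.write_video("video.mp4", frames, fps=10)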
@@ -65,15 +104,26 @@ config = OmegaConf.load("configs/pseudoseg000.yaml")
 config.model.load_pretrained = "seg.ckpt"
 config.model.params.encoder_params.pretrained = False
 seg_model = builder.build_model(config).eval()
+
+config = OmegaConf.load("configs/chunk000.yaml")
+config.model.load_pretrained = "x3d.ckpt"
+config.model.params.pretrained = False
+x3d_model = builder.build_model(config).eval()
+
+config = OmegaConf.load("configs/chunkseq003.yaml")
+config.model.load_pretrained = "seq.ckpt"
+seq_model = builder.build_model(config).eval()
+
 examples = glob.glob("examples/*.nii.gz")
 
 with gr.Blocks(theme="dark-peach") as demo:
     select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
     button_predict = gr.Button("Predict")
-    video_output = gr.Video()
+    video_output = gr.Video(label="Cervical Spine Segmentation")
+    label_output = gr.Label(label="Fracture Predictions", show_label=False)
     button_predict.click(fn=generate_segmentation_video,
                          inputs=select_study,
-                         outputs=video_output)
+                         outputs=[video_output, label_output])
 
 
 if __name__ == "__main__":
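Note: since generate_segmentation_video now returns a (video_path, score_dict) tuple, the click handler maps the two return values onto the two output components in order; gr.Label renders a dict of label → confidence directly. A stripped-down sketch of that wiring, with a placeholder predict function, example path, and made-up scores:

import gradio as gr

def predict(study):
    # ... run segmentation + fracture inference here ...
    scores = {"C1": 0.04, "C2": 0.11, "Overall": 0.13}  # placeholder values
    return "video.mp4", scores

with gr.Blocks() as demo:
    inp = gr.Dropdown(choices=["examples/demo.nii.gz"], label="Select a study")
    btn = gr.Button("Predict")
    vid = gr.Video(label="Cervical Spine Segmentation")
    lab = gr.Label(label="Fracture Predictions")
    btn.click(fn=predict, inputs=inp, outputs=[vid, lab])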
 