xuhw20 committed
Commit 94de731 · verified · 1 Parent(s): 413da6b

Upload folder using huggingface_hub
Files changed (3):
  1. README.md +2 -8
  2. __pycache__/ct_seg.cpython-39.pyc +0 -0
  3. ct_seg.py +316 -0
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title: Uw Ct Seg
-emoji: 👁
-colorFrom: blue
-colorTo: blue
+title: uw-ct-seg
+app_file: ct_seg.py
 sdk: gradio
 sdk_version: 4.44.0
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
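For reference, the Space front matter that results from this change (reconstructed from the diff above) is shown below; app_file now points the Space at the ct_seg.py script added in this commit:

    ---
    title: uw-ct-seg
    app_file: ct_seg.py
    sdk: gradio
    sdk_version: 4.44.0
    ---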
 
 
__pycache__/ct_seg.cpython-39.pyc ADDED
Binary file (9.94 kB).
 
ct_seg.py ADDED
@@ -0,0 +1,316 @@
# --------------------------------------------------------
# BiomedSeg
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Yu Gu (yugu1@microsoft.com), Theo Zhao (theodorezhao@microsoft.com)
# --------------------------------------------------------

import os
import sys

this_file_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(this_file_dir, "../ct_seg"))
import json
import warnings
import PIL
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple

import monai
import cv2
import math
import gradio as gr
import torch
import argparse
import imageio
import numpy as np
import scipy

from torchvision import transforms
from models import dinov2_vitl_transunet
from class_dict import class_dict, dataset_class
from transforms import _MEAN, _STD
from monai import transforms as monai_transforms
from scipy.ndimage import label

id2label = {v: k for k, v in class_dict.items()}
np.random.seed(0)
id2color = {k: list(np.random.choice(range(256), size=3)) for k, v in id2label.items()}


def clean_mask(X):
    """
    Clean the mask for selected labels (1, 2, and 10) by keeping only the largest
    connected component of each label.

    Parameters:
        X (numpy.ndarray): Volumetric mask of shape [N, 1, W, H], with 0 as background.

    Returns:
        numpy.ndarray: Cleaned volumetric mask with the same shape as X.
    """
    # Extract the volume data (N is treated as the depth dimension)
    if X.ndim == 4:
        volume = X[:, 0, :, :]  # Shape: [N, W, H]
    else:
        volume = X

    for label_value in [1, 2, 10]:
        # Create a binary mask for the current label
        mask = (volume == label_value)
        if not np.any(mask):
            continue  # Skip if the label is not present

        # Define connectivity for 3D connected components
        structure = np.ones((3, 3, 3), dtype=int)

        # Label connected components
        labeled_mask, num_features = label(mask, structure=structure)
        if num_features == 0:
            continue  # No connected components found

        # Compute sizes of all connected components
        component_sizes = np.bincount(labeled_mask.ravel())
        component_sizes[0] = 0  # Ignore the background

        # Find the label of the largest connected component
        largest_component_label = component_sizes.argmax()

        # Create a mask for the largest connected component
        largest_component_mask = (labeled_mask == largest_component_label)

        # Remove all other components of the current label
        volume[mask] = 0  # Set all pixels of the current label to background
        volume[largest_component_mask] = label_value  # Restore the largest component

    # Write the cleaned volume back into the original mask
    if X.ndim == 4:
        X[:, 0, :, :] = volume
    else:
        X = volume
    return X
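# Illustrative usage (editor's sketch, not part of the original upload): clean_mask keeps a
# single connected component per label, so a toy mask containing two disjoint blobs of
# label 1 collapses to the larger blob:
#
#     toy = np.zeros((4, 1, 8, 8), dtype=np.int64)
#     toy[0, 0, :2, :2] = 1      # small 4-voxel blob
#     toy[1:3, 0, 4:, 4:] = 1    # larger 32-voxel blob
#     cleaned = clean_mask(toy.copy())
#     # the small blob is removed, the large one survives
#     assert (cleaned[0] == 0).all() and (cleaned[1:3, 0, 4:, 4:] == 1).all()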


def parse_option():
    parser = argparse.ArgumentParser('SEEM Demo', add_help=False)
    parser.add_argument('--model_path', default="ckpt/model_19.pth", metavar="FILE", help='path to model file')
    # parser.add_argument('--model_path', default="ckpt/uw_seg_heart.pth", metavar="FILE", help='path to model file')
    cfg = parser.parse_args()
    return cfg

'''
build args
'''
cfg = parse_option()

pretrained_pth = cfg.model_path

def load_tif_images(file_path):
    vol = imageio.imread(file_path)
    if np.max(vol) <= 1:
        vol = vol * 255
    return vol

def overlay_image_with_mask(image, segmentation_map, path='test.png', ax=None):
    color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8)  # height, width, 3
    for label_id, color in id2color.items():
        color_seg[segmentation_map == label_id, :] = color

    # Show image + mask
    img = np.array(image) * 0.5 + color_seg * 0.5
    img = img.astype(np.uint8)
    return img

def resize_volume(vol, size, max_frames, nearest_neighbor=False):
    W, H, F = vol.shape

    zoom_rate = size / W
    vol_reshape = scipy.ndimage.zoom(
        vol, (zoom_rate, zoom_rate, zoom_rate), order=3 if not nearest_neighbor else 0
    )
    resizeW, resizeH, resizeF = vol_reshape.shape
    if resizeF > max_frames:
        vol_reshape = vol_reshape[:, :, :max_frames]
        resizeF = max_frames
    else:
        resized_max_fr = int(math.ceil(max_frames * zoom_rate))
        vol_reshape = np.concatenate([vol_reshape, np.zeros((resizeW, resizeH, resized_max_fr - resizeF))], axis=-1)
    return vol_reshape, resizeF, zoom_rate
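# Worked example (editor's note, hypothetical input): a 512x512x300 volume resized with
# size=256 and max_frames=512 gives zoom_rate = 256/512 = 0.5, so the volume becomes
# 256x256x150; since 150 < max_frames, it is zero-padded along the frame axis to
# ceil(512 * 0.5) = 256 frames, and the function returns (padded volume, 150, 0.5).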

val_transform = monai_transforms.Compose([monai_transforms.Resized(keys=['image'], spatial_size=(256, 256), mode=['bilinear'])])

def process_volume(vol: np.ndarray, keep_frames: Callable = lambda x: x > 0.025):
    initial_resize = monai.transforms.ResizeWithPadOrCrop((512, 512))
    transform = monai.transforms.CropForeground(keys=["pixel_values"], source_key="pixel_values", return_coords=True)
    crop_vol, start_coords, end_coords = transform(vol)
    keep_frames = np.where(keep_frames(np.mean(np.mean(crop_vol, axis=-1), axis=-1)))[0]
    crop_vol = crop_vol[keep_frames]
    W, H, F = crop_vol.shape
    proc_vol = cv2.equalizeHist(crop_vol.reshape(W, -1).astype(np.uint8)).reshape(W, H, F)
    proc_vol = initial_resize(proc_vol).detach().cpu().numpy().transpose((1, 2, 0))
    proc_vol, max_fr = resize_volume(proc_vol, 256, max_frames=512)[:2]

    images = []
    for i in range(proc_vol.shape[2]):
        image = torch.from_numpy(proc_vol[:, :, i]).unsqueeze(0)
        image_transformed = val_transform({"image": image})["image"]
        images.append(image_transformed)
    images = torch.stack(images)
    if images.max() > 1:
        images = images / 255.0
    # make the images three channels
    images = images.repeat(1, 3, 1, 1)
    for c in range(len(_MEAN)):
        images[:, c, :, :] = (images[:, c, :, :] - _MEAN[c]) / _STD[c]
    return images, max_fr

def untransform(img):
    for c in range(len(_MEAN)):
        img[c] = img[c] * _STD[c] + _MEAN[c]
    if img.max() <= 1:
        img = img * 255
    return img.long()

def process_ct(ct_path: str):
    vol = load_tif_images(ct_path)
    images, frame_indices = process_volume(vol, keep_frames=lambda x: x > 0.025)
    return images, frame_indices
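# End-to-end preprocessing sketch (editor's note; exact shapes depend on the input scan):
#
#     images, n_frames = process_ct("demo/CTseg_57_raw.tif")
#     # images: float tensor of shape [num_slices, 3, 256, 256], normalized with _MEAN/_STD
#     # n_frames: number of real (non-padded) slices, used later to zero out padded frames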

# Ensure the example files are available at these relative paths
examples = [["demo/CTseg_57_raw.tif"],
            ["demo/CTrec-don_1101.tif"]]

'''
build model
'''
class_names = dataset_class["uwseg"]
class_ids = [class_dict[class_name] for class_name in class_names]
model = dinov2_vitl_transunet(pretrained="", num_classes=len(class_dict), img_size=256)
state_dict = torch.load(pretrained_pth)
model.load_state_dict(state_dict)
model = model.cuda()
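# Editor's note: the lines above assume a CUDA-capable device and that the checkpoint at
# cfg.model_path exists. On a CPU-only machine one would typically load with
# torch.load(pretrained_pth, map_location="cpu") and skip the .cuda() call (not part of
# the original code).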

@torch.no_grad()
def inference(image_input):
    if isinstance(image_input, str):
        # image_input is a file path
        file_path = image_input
    else:
        # image_input is a gr.File object
        file_path = image_input.name
    images, frame_indices = process_ct(file_path)
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.float16):
            logits = model(images.cuda())
    # Suppress classes that are not part of the uwseg label set
    for j in range(len(class_dict)):
        if j + 1 not in class_ids:
            logits[:, j] = -1000
    pred = torch.argmax(logits, dim=1) + 1
    pred_mask = (torch.max(logits, dim=1)[0] > 0)
    pred = pred_mask * pred
    pred[frame_indices:] = 0
    pred = torch.from_numpy(clean_mask(pred.cpu().numpy()))
    volume_size = torch.sum(pred == 2).item()
    # 1 voxel = 1 mm^3, convert to cm^3
    volume_size = volume_size / 1000

    # Compute the size of the segmented mask for each slice
    sizes = pred.view(pred.shape[0], -1).sum(dim=1).cpu().numpy()

    segmentation_results = []
    raw_images = []
    for i in range(len(images)):
        images[i] = untransform(images[i])
        raw_image = Image.fromarray(images[i].cpu().permute(1, 2, 0).numpy().astype(np.uint8))
        raw_images.append(raw_image)
        image_with_mask = overlay_image_with_mask(images[i].cpu().permute(1, 2, 0).numpy(), pred[i].squeeze(0).cpu().numpy())
        image_with_mask = Image.fromarray(image_with_mask)
        segmentation_results.append(image_with_mask)
    initial_slice_index = 0
    output_seg = segmentation_results[initial_slice_index]
    output_raw = raw_images[initial_slice_index]
    num_slices = len(segmentation_results)
    initial_size = sizes[initial_slice_index]
    return output_seg, output_raw, segmentation_results, raw_images, gr.update(maximum=num_slices - 1), sizes, f"Heart volume size: {volume_size} cm^3"
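# Editor's note on the volume estimate: pred == 2 is the heart label, so volume_size counts
# heart voxels; under the 1 mm^3-per-voxel assumption stated above, e.g. 65,000 labeled
# voxels / 1000 = 65 cm^3. The reported figure is only meaningful if the resampled voxels
# really are 1 mm^3.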

def update_slice(slice_index, segmentation_results_state, raw_images_state, sizes_text):
    segmentation_results = segmentation_results_state
    raw_images = raw_images_state

    if segmentation_results is None or raw_images is None:
        return None, None, ""
    output_seg = segmentation_results[slice_index]
    output_raw = raw_images[slice_index]

    return output_seg, output_raw, sizes_text

def load_example(example):
    image_file_path = example
    return inference(image_file_path)

title = "CT Segmentation"
description = """
<div style="text-align: left; font-weight: bold;">
<br>
&#x1F32A; Note: The current model is run on <span style="color:blue;">CT Segmentation (UW)</span>
</div>
"""

article = "The demo is run on CT-Seg."

with gr.Blocks(theme=gr.themes.Soft(), title=title, css=".gradio-container { max-width: 1000px; margin: auto; }") as demo:
    # Add title
    with gr.Row():
        gr.Markdown(value="# <span style='color: #6366f1;'>UW CT segmentation</span>", elem_id="title")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(value="""
Welcome to CT Segmentation, an AI model that segments the thorax and heart and computes their volumes.

## How to Use:
0. **Explore Default Examples**: Click on images in the right panel.
1. **Upload Your Image**: something biomedical... but not your lovely pet!

Click **Segment** and see what CT Seg finds for you!
""",
                        elem_id="instructions")
            gr.Markdown("## Step 1: Upload CT volume .tif image (Try examples on the right panel)")
            with gr.Row(equal_height=True):
                input_image = gr.File(label="Input Image", file_types=[".tif"])
                # Initially, set the slider maximum to a default value, e.g., 0
                slice_index_slider = gr.Slider(minimum=0, maximum=0, step=1, label="Slice Index")
            with gr.Row(equal_height=True):
                output_raw = gr.Image(label="Processed Image", interactive=False)
                output_seg = gr.Image(label="Segmentation Results", interactive=False)
            with gr.Row():
                size_text = gr.Textbox(label="Heart volume size", interactive=False)
            with gr.Row():
                button = gr.Button("Segment", interactive=True, variant='primary')
        with gr.Column(scale=0.5):
            gr.Markdown("## Click Default Examples")
            # Initialize state variables
            segmentation_results_state = gr.State()
            raw_images_state = gr.State()
            sizes_state = gr.State()
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_seg, output_raw, segmentation_results_state, raw_images_state, slice_index_slider, sizes_state, size_text],
                fn=load_example,
                cache_examples=False,
                examples_per_page=1,
                run_on_click=True
            )
    # Set up the button click
    button.click(
        fn=inference,
        inputs=[input_image],
        outputs=[output_seg, output_raw, segmentation_results_state, raw_images_state, slice_index_slider, sizes_state, size_text]
    )
    # Set up the slider change
    slice_index_slider.change(
        fn=update_slice,
        inputs=[slice_index_slider, segmentation_results_state, raw_images_state, size_text],
        outputs=[output_seg, output_raw, size_text]
    )

if __name__ == "__main__":
    demo.queue().launch(share=True)
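Note: launching with share=True starts the Gradio app and also creates a temporary public share link. To run the demo locally (assuming the dependencies and the checkpoint are available), the entry point would be roughly:

    python ct_seg.py --model_path ckpt/model_19.pth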