---
license: mit
language:
- en
pipeline_tag: image-segmentation
tags:
- medical
---


![image/jpeg](https://cdn-uploads.huggingface.co/production/uploads/6565b54a9bf6665f10f75441/no60wyvKDTD-WV3pCt2P5.jpeg)

SegVol is a universal and interactive model for volumetric medical image segmentation. It accepts **point, box, and text prompts** and outputs a volumetric segmentation. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.

[Paper](https://arxiv.org/abs/2311.13385) and [Code](https://github.com/BAAI-DCAI/SegVol) have been released.

**Keywords**: 3D medical SAM, volumetric image segmentation

## Quickstart

### Requirements
```
conda create -n segvol_transformers python=3.8
conda activate segvol_transformers
```

[PyTorch v1.11.0](https://pytorch.org/get-started/previous-versions/) or higher is required. After installing it, install the key requirements with the following commands:

```
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
```
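To confirm that the environment matches the pins above, a small check like the following may help. It is a minimal sketch that uses only the standard library (`importlib.metadata`, available from Python 3.8); the helper names `installed_version` and `check_pins` are made up for this example and are not part of SegVol.

```python
from importlib import metadata

# Pins copied from the install commands above; matplotlib is unpinned there.
PINNED = {"monai": "0.9.0", "einops": "0.6.1", "transformers": "4.18.0"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins):
    """Map each package name to (installed, expected, matches)."""
    report = {}
    for package, expected in pins.items():
        found = installed_version(package)
        report[package] = (found, expected, found == expected)
    return report


for package, (found, expected, ok) in check_pins(PINNED).items():
    status = "ok" if ok else "MISMATCH"
    print(f"{package}: installed={found} expected={expected} [{status}]")
```

A `MISMATCH` line does not necessarily mean the scripts below will fail; it only flags a version that differs from the one listed here.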

### Test script

```python
from transformers import AutoModel, AutoTokenizer
import torch

# get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
# If you cannot connect to huggingface.co, download the repo and pass its local directory path to from_pretrained instead of "yuxindu/segvol"
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')

# set case path
# this case can be downloaded from the "Files and versions" tab of yuxindu/segvol on Hugging Face
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories in the order of the label values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
# If you have downloaded our 25 processed datasets, you can skip to here with the processed ct_npy, gt_npy files

# go through zoom_transform to generate zoomout & zoomin views
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# add batch dim manually
for key in ('image', 'label', 'zoom_out_image', 'zoom_out_label'):
    data_item[key] = data_item[key].unsqueeze(0).to(device)

# take liver as the example
cls_idx = 0

# text prompt
text_prompt = [categories[cls_idx]]

# point prompt
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w batch dim

# bbox prompt
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w batch dim

print('prompt done')

# segvol test forward
# use_zoom: use zoom-out-zoom-in
# point_prompt_group: use point prompt
# bbox_prompt_group: use bbox prompt
# text_prompt: use text prompt
logits_mask = model.forward_test(
    image=data_item['image'],
    zoomed_image=data_item['zoom_out_image'],
    # point_prompt_group=[point_prompt, point_prompt_map],
    bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
    text_prompt=text_prompt,
    use_zoom=False,
)

# calculate dice score
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx])
print(dice)

# save prediction as nii.gz file
save_path = './Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0], 
                           start_coord=data_item['foreground_start_coord'], 
                           end_coord=data_item['foreground_end_coord'])
print('done')
```
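The Dice score printed by the test script measures the overlap between the predicted mask and the ground truth. As a rough illustration of the metric itself (not of `model.processor.dice_score`, whose thresholding and edge-case handling are not shown here), the sketch below computes Dice on two already-binarized masks given as flat 0/1 sequences. The function name `binary_dice` and the convention for two empty masks are assumptions made for this example.

```python
def binary_dice(pred, target):
    """Dice = 2 * |pred AND target| / (|pred| + |target|) for flat 0/1 masks."""
    pred, target = list(pred), list(target)
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of voxels")
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    if total == 0:
        # Assumption: two empty masks count as a perfect match.
        return 1.0
    return 2.0 * intersection / total


# Tiny example: 4 voxels, the prediction covers one extra voxel.
pred_mask = [1, 1, 0, 0]
true_mask = [1, 0, 0, 0]
print(binary_dice(pred_mask, true_mask))  # 2 * 1 / (2 + 1)
```

A score of 1.0 means the two masks coincide and 0.0 means they do not overlap at all.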

### Training script

```python
from transformers import AutoModel, AutoTokenizer
import torch

# get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
# If you cannot connect to huggingface.co, download the repo and pass its local directory path to from_pretrained instead of "yuxindu/segvol"
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=False)
model.model.text_encoder.tokenizer = clip_tokenizer
model.train()
model.to(device)
print('model load done')

# set case path
# this case can be downloaded from the "Files and versions" tab of yuxindu/segvol on Hugging Face
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories in the order of the label values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
# If you have downloaded our 25 processed datasets, you can skip to here with the processed ct_npy, gt_npy files

# go through train transform
data_item = model.processor.train_transform(ct_npy, gt_npy)

# training example
# add batch dim manually
image, gt3D = data_item["image"].unsqueeze(0).to(device), data_item["label"].unsqueeze(0).to(device) # add batch dim

loss_step_avg = 0
for cls_idx in range(len(categories)):
    # optimizer.zero_grad()  # uncomment once you have created an optimizer
    organs_cls = categories[cls_idx]
    labels_cls = gt3D[:, cls_idx]
    print(image.shape, organs_cls, labels_cls.shape)
    loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
    loss_step_avg += loss.item()
    loss.backward()
    # optimizer.step()  # without an optimizer step, the weights are not updated

loss_step_avg /= len(categories)
print(f'AVG loss {loss_step_avg}')

# save ckpt
model.save_pretrained('./ckpt')
```
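Both scripts pass `categories` in the same order as the integer label values in the ground truth mask, and the training loop reads one channel per category with `gt3D[:, cls_idx]`. The sketch below illustrates that convention on a flat list of voxel labels. `split_labels` is a hypothetical helper, and it is an assumption (not confirmed here) that `preprocess_ct_gt` builds its per-category channels in an equivalent way.

```python
def split_labels(label_values, categories):
    """Return {category: binary mask}, where category i owns label value i + 1."""
    masks = {}
    for index, name in enumerate(categories):
        value = index + 1  # label value 0 is treated as background
        masks[name] = [1 if v == value else 0 for v in label_values]
    return masks


categories = ["liver", "kidney", "spleen", "pancreas"]
voxels = [0, 1, 1, 2, 4, 0]  # toy flattened ground truth mask
masks = split_labels(voxels, categories)
print(masks["liver"])
print(masks["spleen"])
```

In this toy mask no voxel carries the value 3, so the "spleen" channel is all zeros; a category with no voxels in a given case would look like this.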