---
license: mit
language:
- en
pipeline_tag: image-segmentation
tags:
- medical
---


![image/jpeg](https://cdn-uploads.huggingface.co/production/uploads/6565b54a9bf6665f10f75441/no60wyvKDTD-WV3pCt2P5.jpeg)

SegVol is a universal and interactive model for volumetric medical image segmentation. It accepts **point, box, and text prompts** and outputs volumetric segmentation masks. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.

**Keywords**: 3D medical SAM, volumetric image segmentation

## Quickstart

### Requirements
```
conda create -n segvol_transformers python=3.8
conda activate segvol_transformers
```

[PyTorch v1.11.0](https://pytorch.org/get-started/previous-versions/) (or a higher version) is required.
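As a minimal install sketch, assuming a CUDA 11.3 setup (check the linked previous-versions page for the command matching your hardware):

```
pip install torch==1.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
```

Then install the remaining key requirements: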

```
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
```

### Test script

```python
from transformers import AutoModel, AutoTokenizer
import torch

# get device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'  # uncomment to force CPU inference

# load model
# If you cannot connect to huggingface.co, download the repo and pass the local directory path to from_pretrained instead of "yuxindu/segvol"
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')

# set case path
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories, corresponding to the unique values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)

# go through zoom_transform to generate zoomout & zoomin views
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# add batch dim manually
for key in ['image', 'label', 'zoom_out_image', 'zoom_out_label']:
    data_item[key] = data_item[key].unsqueeze(0).to(device)

# take liver as the example
cls_idx = 0

# text prompt
text_prompt = [categories[cls_idx]]

# point prompt
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w batch dim

# bbox prompt
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w batch dim

print('prompt done')

# segvol test forward
# use_zoom: use zoom-out-zoom-in
# point_prompt_group: use point prompt
# bbox_prompt_group: use bbox prompt
# text_prompt: use text prompt
logits_mask = model.forward_test(image=data_item['image'],
      zoomed_image=data_item['zoom_out_image'],
      # point_prompt_group=[point_prompt, point_prompt_map],
      bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
      text_prompt=text_prompt,
      use_zoom=False
      )

# compute the dice score between prediction and ground truth
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx])
print(dice)

# save prediction as nii.gz file
save_path = './Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0], 
                           start_coord=data_item['foreground_start_coord'], 
                           end_coord=data_item['foreground_end_coord'])
print('done')
```
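The commented-out arguments of `forward_test` show the other prompt modes. As a sketch reusing the objects built above, the same case can be run with the point prompt and the zoom-out-zoom-in pipeline enabled:

```python
# point prompt + zoom-out-zoom-in inference for the same case
logits_mask = model.forward_test(image=data_item['image'],
      zoomed_image=data_item['zoom_out_image'],
      point_prompt_group=[point_prompt, point_prompt_map],
      text_prompt=text_prompt,
      use_zoom=True
      )
```

Prompt types can be mixed and matched per call; with `use_zoom=True`, the model first segments the downsampled (zoom-out) view and then refines the result on the full-resolution (zoom-in) view.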

### Training script