license: mit
language:
  - en
pipeline_tag: image-segmentation
tags:
  - medical


SegVol is a universal and interactive model for volumetric medical image segmentation. It accepts point, box, and text prompts and outputs volumetric segmentation masks. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.

Keywords: 3D medical SAM, volumetric image segmentation

Quickstart

Requirements

conda create -n segvol_transformers python=3.8
conda activate segvol_transformers

PyTorch v1.11.0 or higher is required. Then install the key dependencies with the following commands:

pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
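
After installation, it can help to confirm that the environment matches the versions listed above. The following is a minimal sketch, not part of SegVol: the helper names (`version_tuple`, `meets_minimum`, `check_environment`) are hypothetical, and it treats every listed version as a lower bound, which is an assumption for the pinned packages (the README does not say whether newer releases of `monai`, `einops`, or `transformers` work).

```python
import re
from importlib import metadata  # standard library since Python 3.8

# Versions taken from the requirements above. Treating the pinned
# versions as lower bounds is an assumption of this sketch.
REQUIRED = {
    "torch": "1.11.0",
    "monai": "0.9.0",
    "einops": "0.6.1",
    "transformers": "4.18.0",
}


def version_tuple(version):
    """Turn a string such as '1.13.1+cu117' into (1, 13, 1)."""
    public = version.split("+")[0]  # drop a local build tag like '+cu117'
    numbers = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component
        numbers.append(int(match.group()))
    return tuple(numbers)


def meets_minimum(installed, required):
    """Return True if the installed version is at least the required one."""
    a, b = version_tuple(installed), version_tuple(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so that '1.11' equals '1.11.0'
    b += (0,) * (width - len(b))
    return a >= b


def check_environment(required=REQUIRED):
    """Map each package to (installed version or None, requirement met)."""
    report = {}
    for name, minimum in required.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_environment().items():
        status = "ok" if ok else "missing or too old"
        print(f"{name}: {installed} ({status})")
```

Run as a script inside the `segvol_transformers` environment, it prints one line per package.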

Test script

from transformers import AutoModel, AutoTokenizer
import torch

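# select device: use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU inference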

# load model
# If you cannot connect to huggingface.co, download the repo and pass the local directory path to from_pretrained instead of "yuxindu/segvol"
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')

# set case path
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories; they correspond, in order, to the unique values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)

# apply zoom_transform to generate the zoom-out and zoom-in views
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# add the batch dimension manually and move the tensors to the device
for key in ('image', 'label', 'zoom_out_image', 'zoom_out_label'):
    data_item[key] = data_item[key].unsqueeze(0).to(device)

# take liver as the example
cls_idx = 0

# text prompt
text_prompt = [categories[cls_idx]]

# point prompt
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w batch dim

# bbox prompt
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w batch dim

print('prompt done')

# segvol test forward
# use_zoom: use zoom-out-zoom-in
# point_prompt_group: use point prompt
# bbox_prompt_group: use bbox prompt
# text_prompt: use text prompt
logits_mask = model.forward_test(
    image=data_item['image'],
    zoomed_image=data_item['zoom_out_image'],
    # point_prompt_group=[point_prompt, point_prompt_map],
    bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
    text_prompt=text_prompt,
    use_zoom=False,
)

# calculate the Dice score
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx])
print(dice)

# save prediction as nii.gz file
save_path = './Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0], 
                           start_coord=data_item['foreground_start_coord'], 
                           end_coord=data_item['foreground_end_coord'])
print('done')
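
The script above reports a Dice score through `model.processor.dice_score`, whose internals are not shown in this README. As a reference for reading that number, here is a minimal pure-Python sketch of the standard Dice coefficient, 2|A ∩ B| / (|A| + |B|), on flattened binary masks. The function names and the choice to binarize logits at 0 (equivalent to a sigmoid probability of 0.5) are assumptions of this sketch and may differ from what the SegVol processor does.

```python
def binarize(logits, threshold=0.0):
    """Turn raw logits into a 0/1 mask.

    Assumption: a logit above 0 counts as foreground, which matches
    a sigmoid probability above 0.5.
    """
    return [1 if value > threshold else 0 for value in logits]


def dice_coefficient(pred, target):
    """Standard Dice coefficient for two flat binary sequences."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    # Voxels that are foreground in both masks.
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    # Total foreground voxels across both masks.
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    if total == 0:
        # Both masks are empty; treating this as perfect agreement
        # is a convention chosen for this sketch.
        return 1.0
    return 2.0 * intersection / total


if __name__ == "__main__":
    logits = [2.3, 0.7, -1.1, -4.0]  # toy values, not model output
    target = [1, 0, 0, 0]
    # Prediction has two foreground voxels, one of which overlaps the target.
    print(dice_coefficient(binarize(logits), target))
```

A score of 1.0 means the predicted and ground-truth masks coincide exactly, and 0.0 means they share no foreground voxels.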

Training script