---
license: mit
language:
- en
pipeline_tag: image-segmentation
tags:
- medical
---

![image/jpeg](https://cdn-uploads.huggingface.co/production/uploads/6565b54a9bf6665f10f75441/no60wyvKDTD-WV3pCt2P5.jpeg)

SegVol is a universal, interactive model for volumetric medical image segmentation. It accepts **point, box, and text prompts** and outputs a volumetric segmentation. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.

**Keywords**: 3D medical SAM, volumetric image segmentation

## Quickstart

### Requirements

```
conda create -n segvol_transformers python=3.8
conda activate segvol_transformers
```

[PyTorch v1.11.0](https://pytorch.org/get-started/previous-versions/) or higher is required. Install the key requirements with the following commands:

```
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
```

### Test script

```python
from transformers import AutoModel, AutoTokenizer
import torch

# get device: use the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU

# load model
# If you cannot connect to huggingface.co, download the repo and pass the
# local directory path to from_pretrained in place of "yuxindu/segvol"
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')

# set case path
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories, corresponding to the unique values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)

# go through zoom_transform to generate zoom-out & zoom-in views
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# add batch dim manually and move the tensors to the target device
for key in ('image', 'label', 'zoom_out_image', 'zoom_out_label'):
    data_item[key] = data_item[key].unsqueeze(0).to(device)

# take liver as the example
cls_idx = 0

# text prompt
text_prompt = [categories[cls_idx]]

# point prompt (inputs w/o batch dim, outputs w/ batch dim)
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)

# bbox prompt (inputs w/o batch dim, outputs w/ batch dim)
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)

print('prompt done')

# segvol test forward
# use_zoom: use zoom-out-zoom-in
# point_prompt_group: use point prompt
# bbox_prompt_group: use bbox prompt
# text_prompt: use text prompt
logits_mask = model.forward_test(image=data_item['image'],
                                 zoomed_image=data_item['zoom_out_image'],
                                 # point_prompt_group=[point_prompt, point_prompt_map],
                                 bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
                                 text_prompt=text_prompt,
                                 use_zoom=False
                                 )

# calculate dice score
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx])
print(dice)

# save prediction as nii.gz file
save_path = './Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0],
                           start_coord=data_item['foreground_start_coord'],
                           end_coord=data_item['foreground_end_coord'])
print('done')
```

### Training script