yuxindu committed on
Commit
48504ce
·
verified ·
1 Parent(s): 6d29dfb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +68 -0
README.md CHANGED
@@ -34,7 +34,75 @@ pip install matplotlib
34
  ### Test script
35
 
36
  ```python
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  ```
39
 
40
  ### Training script
 
34
  ### Test script
35
 
36
  ```python
37
from transformers import AutoModel, AutoTokenizer
import torch

# --- Device selection --------------------------------------------------------
# Use the second GPU when CUDA is available, otherwise fall back to CPU.
# (The original snippet immediately overwrote this with 'cpu', making the
# CUDA check dead code; keep the override available but opt-in.)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU inference

# --- Load model ---------------------------------------------------------------
# NOTE: trust_remote_code=True executes custom model code shipped with the
# checkpoint — only use it with checkpoints you trust.
clip_tokenizer = AutoTokenizer.from_pretrained("./segvol")
model = AutoModel.from_pretrained("./segvol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')

# --- Case paths ---------------------------------------------------------------
ct_path = './Case_image_00001_0000.nii.gz'
gt_path = './Case_label_00001.nii.gz'

# Categories correspond to the unique values (1, 2, 3, 4, ...) in the
# ground-truth mask.
categories = ["liver", "kidney", "spleen", "pancreas"]

# Convert the NIfTI CT volume and ground-truth mask to npy data format.
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)

# Go through zoom_transform to generate the zoom-out & zoom-in views.
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# Add the batch dimension manually and move everything to the device.
data_item['image'], data_item['label'], data_item['zoom_out_image'], data_item['zoom_out_label'] = \
data_item['image'].unsqueeze(0).to(device), data_item['label'].unsqueeze(0).to(device), data_item['zoom_out_image'].unsqueeze(0).to(device), data_item['zoom_out_label'].unsqueeze(0).to(device)

# Take liver (index 0 in `categories`) as the example class.
cls_idx = 0

# Text prompt: the category name itself.
text_prompt = [categories[cls_idx]]

# Point prompt (inputs w/o batch dim, outputs w batch dim).
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)

# Bbox prompt (inputs w/o batch dim, outputs w batch dim).
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)

print('prompt done')

# SegVol test forward pass.
#   use_zoom:           enable the zoom-out-zoom-in inference strategy
#   point_prompt_group: use the point prompt
#   bbox_prompt_group:  use the bbox prompt
#   text_prompt:        use the text prompt
logits_mask = model.forward_test(image=data_item['image'],
                                 zoomed_image=data_item['zoom_out_image'],
                                 # point_prompt_group=[point_prompt, point_prompt_map],
                                 bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
                                 text_prompt=text_prompt,
                                 use_zoom=False
                                 )

# Compute the Dice score between the prediction and the ground truth.
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx])
print(dice)

# Save the prediction as a nii.gz file, restored to the original volume extent.
save_path = './Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0],
                           start_coord=data_item['foreground_start_coord'],
                           end_coord=data_item['foreground_end_coord'])
print('done')
106
  ```
107
 
108
  ### Training script