Update README.md
Browse files
README.md
CHANGED
@@ -39,7 +39,6 @@ import torch
|
|
39 |
|
40 |
# get device
|
41 |
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
42 |
-
device = 'cpu'
|
43 |
|
44 |
# load model
|
45 |
# IF you cannot connect to huggingface.co, you can download the repo and set from_pretrained path as the loacl dir path, replacing "yuxindu/segvol"
|
@@ -109,3 +108,56 @@ print('done')
|
|
109 |
```
|
110 |
|
111 |
### Training script
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
# get device
|
41 |
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
|
|
42 |
|
43 |
# load model
|
44 |
# IF you cannot connect to huggingface.co, you can download the repo and set from_pretrained path as the loacl dir path, replacing "yuxindu/segvol"
|
|
|
108 |
```
|
109 |
|
110 |
### Training script
|
111 |
+
|
112 |
+
```python
|
113 |
+
from transformers import AutoModel, AutoTokenizer
|
114 |
+
import torch
|
115 |
+
|
116 |
+
# get device
|
117 |
+
device = torch.device("cuda:1" if torch.cuda.is_available() else "CPU")
|
118 |
+
|
119 |
+
# load model
|
120 |
+
# IF you cannot connect to huggingface.co, you can download the repo and set from_pretrained path as the loacl dir path, replacing "yuxindu/segvol"
|
121 |
+
clip_tokenizer = AutoTokenizer.from_pretrained("./segvol")
|
122 |
+
model = AutoModel.from_pretrained("./segvol", trust_remote_code=True, test_mode=False)
|
123 |
+
model.model.text_encoder.tokenizer = clip_tokenizer
|
124 |
+
model.train()
|
125 |
+
model.to(device)
|
126 |
+
print('model load done')
|
127 |
+
|
128 |
+
# set case path
|
129 |
+
# you can download this case from huggingface yuxindu/segvol files and versions
|
130 |
+
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
|
131 |
+
gt_path = 'path/to/Case_label_00001.nii.gz'
|
132 |
+
|
133 |
+
# set categories, corresponding to the unique values(1, 2, 3, 4, ...) in ground truth mask
|
134 |
+
categories = ["liver", "kidney", "spleen", "pancreas"]
|
135 |
+
|
136 |
+
# generate npy data format
|
137 |
+
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
|
138 |
+
# IF you have download our 25 processed datasets, you can skip to here with the processed ct_npy, gt_npy files
|
139 |
+
|
140 |
+
# go through train transform
|
141 |
+
data_item = model.processor.train_transform(ct_npy, gt_npy)
|
142 |
+
|
143 |
+
# training example
|
144 |
+
# add batch dim manually
|
145 |
+
image, gt3D = data_item["image"].unsqueeze(0).to(device), data_item["label"].unsqueeze(0).to(device) # add batch dim
|
146 |
+
|
147 |
+
loss_step_avg = 0
|
148 |
+
for cls_idx in range(len(categories)):
|
149 |
+
# optimizer.zero_grad()
|
150 |
+
organs_cls = categories[cls_idx]
|
151 |
+
labels_cls = gt3D[:, cls_idx]
|
152 |
+
print(image.shape, organs_cls, labels_cls.shape)
|
153 |
+
loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
|
154 |
+
loss_step_avg += loss.item()
|
155 |
+
loss.backward()
|
156 |
+
# optimizer.step()
|
157 |
+
|
158 |
+
loss_step_avg /= len(categories)
|
159 |
+
print(f'AVG loss {loss_step_avg}')
|
160 |
+
|
161 |
+
# save ckpt
|
162 |
+
model.save_pretrained('./ckpt')
|
163 |
+
```
|