yuxindu committed
Commit 2568f65 · verified · 1 Parent(s): 222046a

Update README.md

Files changed (1): README.md +53 -1
README.md CHANGED
@@ -39,7 +39,6 @@ import torch
 
 # get device
 device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
-device = 'cpu'
 
 # load model
 # If you cannot connect to huggingface.co, you can download the repo and set the from_pretrained path to the local dir path, replacing "yuxindu/segvol"
@@ -109,3 +108,56 @@ print('done')
 ```
 
 ### Training script
+
+```python
+from transformers import AutoModel, AutoTokenizer
+import torch
+
+# get device
+device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+
+# load model
+# If you cannot connect to huggingface.co, you can download the repo and set the from_pretrained path to the local dir path, replacing "yuxindu/segvol"
+clip_tokenizer = AutoTokenizer.from_pretrained("./segvol")
+model = AutoModel.from_pretrained("./segvol", trust_remote_code=True, test_mode=False)
+model.model.text_encoder.tokenizer = clip_tokenizer
+model.train()
+model.to(device)
+print('model load done')
+
+# set case path
+# you can download this case from the Files and versions tab of the huggingface yuxindu/segvol repo
+ct_path = 'path/to/Case_image_00001_0000.nii.gz'
+gt_path = 'path/to/Case_label_00001.nii.gz'
+
+# set categories, corresponding to the unique values (1, 2, 3, 4, ...) in the ground truth mask
+categories = ["liver", "kidney", "spleen", "pancreas"]
+
+# generate npy data format
+ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
+# If you have downloaded our 25 processed datasets, you can skip to here with the processed ct_npy, gt_npy files
+
+# go through train transform
+data_item = model.processor.train_transform(ct_npy, gt_npy)
+
+# training example
+# add batch dim manually
+image, gt3D = data_item["image"].unsqueeze(0).to(device), data_item["label"].unsqueeze(0).to(device)
+
+loss_step_avg = 0
+for cls_idx in range(len(categories)):
+    # optimizer.zero_grad()
+    organs_cls = categories[cls_idx]
+    labels_cls = gt3D[:, cls_idx]
+    print(image.shape, organs_cls, labels_cls.shape)
+    loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
+    loss_step_avg += loss.item()
+    loss.backward()
+    # optimizer.step()
+
+loss_step_avg /= len(categories)
+print(f'AVG loss {loss_step_avg}')
+
+# save ckpt
+model.save_pretrained('./ckpt')
+```
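
The added training script leaves its optimizer calls commented out. Below is a minimal sketch of how the loop could be wired to an actual optimizer, reusing `model`, `categories`, `image`, and `gt3D` from the script above; the choice of `torch.optim.AdamW` and the learning rate `1e-4` are placeholder assumptions, not values given in this commit.

```python
import torch

# Assumption: AdamW with lr=1e-4 is a placeholder choice, not a value
# taken from this commit.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

loss_step_avg = 0
for cls_idx in range(len(categories)):
    optimizer.zero_grad()             # clear gradients from the previous class
    organs_cls = categories[cls_idx]  # text prompt for this class
    labels_cls = gt3D[:, cls_idx]     # matching channel of the ground-truth mask
    loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
    loss_step_avg += loss.item()
    loss.backward()                   # compute gradients for this class
    optimizer.step()                  # update weights once per class

loss_step_avg /= len(categories)
print(f'AVG loss {loss_step_avg}')
```

Stepping once per class mirrors where the commented-out calls sit in the original loop; accumulating gradients over all classes before a single `optimizer.step()` would be an equally reasonable variant.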