yuxin commited on
Commit
2f9c026
1 Parent(s): b3d03a4
Files changed (1) hide show
  1. model_segvol_single.py +101 -105
model_segvol_single.py CHANGED
@@ -1,5 +1,5 @@
1
  from transformers import PreTrainedModel
2
- from .config_segvol import SegVolConfig
3
  import numpy as np
4
  import monai.transforms as transforms
5
 
@@ -77,17 +77,111 @@ class SegVolProcessor():
77
 
78
  # transform
79
  item = self.transform(item)
80
- print('ready for zoom out')
81
  item_zoom_out = self.zoom_out_transform(item)
82
  item['zoom_out_image'] = item_zoom_out['image']
83
  item['zoom_out_label'] = item_zoom_out['label']
84
- print( 'Zoom_in image shape: ', item['image'].shape,
85
- '\nZoom_in label shape: ', item['label'].shape,
86
- '\nZoom_out image shape: ', item['zoom_out_image'].shape,
87
- '\nZoom_out label shape: ', item['zoom_out_label'].shape,
88
- )
89
  return item
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  class MinMaxNormalization(transforms.Transform):
92
  def __call__(self, data):
93
  d = dict(data)
@@ -767,104 +861,6 @@ def _get_scan_interval(
767
  scan_interval.append(interval if interval > 0 else 1)
768
  return tuple(scan_interval)
769
 
770
-
771
- def generate_box(pred_pre, bbox_shift=None):
772
- meaning_post_label = pred_pre # [h, w, d]
773
- ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
774
- if all(tensor.nelement() == 0 for tensor in ones_idx):
775
- bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
776
- # print(bboxes, bboxes.shape)
777
- return bboxes
778
- min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
779
- max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]
780
-
781
-
782
- if bbox_shift is None:
783
- corner_min = []
784
- corner_max = []
785
- shape = meaning_post_label.shape
786
- for coor in min_coords:
787
- coor_ = max(0, coor)
788
- corner_min.append(coor_)
789
- for idx, coor in enumerate(max_coords):
790
- coor_ = min(shape[idx], coor)
791
- corner_max.append(coor_)
792
- corner_min = torch.tensor(corner_min)
793
- corner_max = torch.tensor(corner_max)
794
- return torch.cat((corner_min, corner_max), dim=0)
795
- else:
796
- # add perturbation to bounding box coordinates
797
- corner_min = []
798
- corner_max = []
799
- shape = meaning_post_label.shape
800
- for coor in min_coords:
801
- coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
802
- corner_min.append(coor_)
803
- for idx, coor in enumerate(max_coords):
804
- coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
805
- corner_max.append(coor_)
806
- corner_min = torch.tensor(corner_min)
807
- corner_max = torch.tensor(corner_max)
808
- return torch.cat((corner_min, corner_max), dim=0)
809
-
810
-
811
- def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
812
- spacial_dim = 3
813
- points = torch.zeros((0, 3))
814
- labels = torch.zeros((0))
815
- pos_thred = 0.9
816
- neg_thred = 0.1
817
-
818
- # get pos/net indices
819
- positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) # ([pos x], [pos y], [pos z])
820
- negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)
821
-
822
- ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
823
- if all(tmp.nelement() == 0 for tmp in ones_idx):
824
- # all neg
825
- num_positive_extra = 0
826
- selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
827
- points = torch.cat((points, selected_positive_point), dim=0)
828
- labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
829
- else:
830
- # random select a pos point
831
- random_idx = torch.randint(len(positive_indices[0]), (1,))
832
- selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
833
- points = torch.cat((points, selected_positive_point), dim=0)
834
- labels = torch.cat((labels, torch.ones((1))))
835
-
836
- if num_positive_extra > 0:
837
- pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
838
- extra_positive_points = []
839
- for pos_idx in pos_idx_list:
840
- extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
841
- extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
842
- points = torch.cat((points, extra_positive_points), dim=0)
843
- labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))
844
-
845
- if num_negative_extra > 0:
846
- neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
847
- extra_negative_points = []
848
- for neg_idx in neg_idx_list:
849
- extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
850
- extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
851
- points = torch.cat((points, extra_negative_points), dim=0)
852
- labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
853
- # print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
854
- # print('==> points ', points.shape, labels)
855
-
856
- if fix_extra_point_num is None:
857
- left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
858
- else:
859
- left_point_num = fix_extra_point_num + 1 - labels.shape[0]
860
-
861
- for _ in range(left_point_num):
862
- ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
863
- points = torch.cat((points, ignore_point), dim=0)
864
- labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
865
-
866
- return (points, labels)
867
-
868
  # build 3D SAM
869
  import torch
870
  import numpy as np
 
1
  from transformers import PreTrainedModel
2
+ from config_segvol import SegVolConfig
3
  import numpy as np
4
  import monai.transforms as transforms
5
 
 
77
 
78
  # transform
79
  item = self.transform(item)
 
80
  item_zoom_out = self.zoom_out_transform(item)
81
  item['zoom_out_image'] = item_zoom_out['image']
82
  item['zoom_out_label'] = item_zoom_out['label']
83
+ del item['image_transforms']
84
+ del item['label_transforms']
 
 
 
85
  return item
86
 
87
+
88
+ def generate_box(self, pred_pre, bbox_shift=None):
89
+ meaning_post_label = pred_pre # [h, w, d]
90
+ ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
91
+ if all(tensor.nelement() == 0 for tensor in ones_idx):
92
+ bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
93
+ # print(bboxes, bboxes.shape)
94
+ return bboxes
95
+ min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
96
+ max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]
97
+
98
+
99
+ if bbox_shift is None:
100
+ corner_min = []
101
+ corner_max = []
102
+ shape = meaning_post_label.shape
103
+ for coor in min_coords:
104
+ coor_ = max(0, coor)
105
+ corner_min.append(coor_)
106
+ for idx, coor in enumerate(max_coords):
107
+ coor_ = min(shape[idx], coor)
108
+ corner_max.append(coor_)
109
+ corner_min = torch.tensor(corner_min)
110
+ corner_max = torch.tensor(corner_max)
111
+ return torch.cat((corner_min, corner_max), dim=0)
112
+ else:
113
+ # add perturbation to bounding box coordinates
114
+ corner_min = []
115
+ corner_max = []
116
+ shape = meaning_post_label.shape
117
+ for coor in min_coords:
118
+ coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
119
+ corner_min.append(coor_)
120
+ for idx, coor in enumerate(max_coords):
121
+ coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
122
+ corner_max.append(coor_)
123
+ corner_min = torch.tensor(corner_min)
124
+ corner_max = torch.tensor(corner_max)
125
+ return torch.cat((corner_min, corner_max), dim=0)
126
+
127
+
128
+ def select_points(self, preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
129
+ spacial_dim = 3
130
+ points = torch.zeros((0, 3))
131
+ labels = torch.zeros((0))
132
+ pos_thred = 0.9
133
+ neg_thred = 0.1
134
+
135
+ # get pos/net indices
136
+ positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) # ([pos x], [pos y], [pos z])
137
+ negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)
138
+
139
+ ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
140
+ if all(tmp.nelement() == 0 for tmp in ones_idx):
141
+ # all neg
142
+ num_positive_extra = 0
143
+ selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
144
+ points = torch.cat((points, selected_positive_point), dim=0)
145
+ labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
146
+ else:
147
+ # random select a pos point
148
+ random_idx = torch.randint(len(positive_indices[0]), (1,))
149
+ selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
150
+ points = torch.cat((points, selected_positive_point), dim=0)
151
+ labels = torch.cat((labels, torch.ones((1))))
152
+
153
+ if num_positive_extra > 0:
154
+ pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
155
+ extra_positive_points = []
156
+ for pos_idx in pos_idx_list:
157
+ extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
158
+ extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
159
+ points = torch.cat((points, extra_positive_points), dim=0)
160
+ labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))
161
+
162
+ if num_negative_extra > 0:
163
+ neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
164
+ extra_negative_points = []
165
+ for neg_idx in neg_idx_list:
166
+ extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
167
+ extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
168
+ points = torch.cat((points, extra_negative_points), dim=0)
169
+ labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
170
+ # print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
171
+ # print('==> points ', points.shape, labels)
172
+
173
+ if fix_extra_point_num is None:
174
+ left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
175
+ else:
176
+ left_point_num = fix_extra_point_num + 1 - labels.shape[0]
177
+
178
+ for _ in range(left_point_num):
179
+ ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
180
+ points = torch.cat((points, ignore_point), dim=0)
181
+ labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
182
+
183
+ return (points, labels)
184
+
185
  class MinMaxNormalization(transforms.Transform):
186
  def __call__(self, data):
187
  d = dict(data)
 
861
  scan_interval.append(interval if interval > 0 else 1)
862
  return tuple(scan_interval)
863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864
  # build 3D SAM
865
  import torch
866
  import numpy as np