yuxin committed
Commit d0601bd • Parent: 6f1b94e

add config

Files changed:
- config.json (+0 -1)
- model_segvol_single.py (+28 -36)
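This commit drops the `custom_device` field from the model configuration and threads device handling through at call time instead: methods now either accept an explicit `device` argument or read it from an input tensor (`image.device`, `image_embedding.device`). A minimal sketch of that pattern, using a toy module rather than the actual SegVol classes (all names below are illustrative, not part of this repo):

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    """Illustrative only: mirrors the commit's pattern of deriving the device
    from an incoming tensor instead of storing a custom_device attribute."""
    def forward(self, image_embedding, prompt_points):
        device = image_embedding.device            # infer device at call time
        prompt_points = prompt_points.to(device)   # move prompts to that device
        return image_embedding.sum() + prompt_points.sum()

decoder = ToyDecoder()
emb = torch.randn(1, 8)     # CPU here; works the same if emb lives on CUDA
pts = torch.randn(1, 3)
print(decoder(emb, pts))    # no device configuration needed anywhere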
config.json CHANGED
@@ -6,7 +6,6 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
-  "custom_device": "cpu",
   "model_type": "segvol",
   "patch_size": [
     4,
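With `custom_device` removed from `config.json`, the checkpoint no longer pins itself to a device at load time; placement becomes an ordinary caller decision. A hedged loading sketch (the repository id below is a placeholder, not taken from this commit; `trust_remote_code=True` is needed because `auto_map` points at the custom `model_segvol_single` module):

import torch
from transformers import AutoModel

# Placeholder repo id; substitute the actual SegVol repository.
model = AutoModel.from_pretrained("your-namespace/segvol", trust_remote_code=True)

# Device placement now happens here, not via a config field.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)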
model_segvol_single.py CHANGED
@@ -9,15 +9,11 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
-        custom_device='cpu',
-        # clip_model='.',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
-        self.custom_device = custom_device
-        # self.clip_model = clip_model
         super().__init__(**kwargs)
 
 class SegVolModel(PreTrainedModel):
@@ -38,14 +34,11 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
-            custom_device=self.config.custom_device,
             # clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
-        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size, custom_device=self.config.custom_device)
-
-        self.custom_device = self.config.custom_device
+        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
 
     def forward_test(self,
         image,
@@ -53,7 +46,8 @@ class SegVolModel(PreTrainedModel):
         text_prompt=None,
         bbox_prompt_group=None,
         point_prompt_group=None,
-        use_zoom=True):
+        use_zoom=True,):
+        device = image.device
         assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
         assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt'
         bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
@@ -110,7 +104,7 @@ class SegVolModel(PreTrainedModel):
         ## inference
         with torch.no_grad():
             logits_single_cropped = sliding_window_inference(
-                image_single_cropped.to(self.custom_device), prompt_reflection,
+                image_single_cropped.to(device), prompt_reflection,
                 self.config.spatial_size, 1, self.model, 0.5,
                 text=text_prompt,
                 use_box=bbox_prompt is not None,
@@ -128,7 +122,7 @@ class SegVolModel(PreTrainedModel):
 
 # processor
 class SegVolProcessor():
-    def __init__(self, spatial_size, custom_device) -> None:
+    def __init__(self, spatial_size) -> None:
         self.img_loader = transforms.LoadImage()
         self.transform4test = transforms.Compose(
             [
@@ -140,7 +134,6 @@ class SegVolProcessor():
             ]
         )
         self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
-        self.custom_device = custom_device
         self.transform4train = transforms.Compose(
             [
                 # transforms.AddChanneld(keys=["image"]),
@@ -217,24 +210,24 @@ class SegVolProcessor():
         item['zoom_out_label'] = item_zoom_out['label']
         return item
 
-    def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0):
+    def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0, device='cpu'):
         point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
-        points_single = (point.unsqueeze(0).float().to(self.custom_device), point_label.unsqueeze(0).float().to(self.custom_device))
+        points_single = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
         binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
         return points_single, binary_points_resize
 
-    def bbox_prompt_b(self, label_single_resize):
-        box_single = generate_box(label_single_resize).unsqueeze(0).float().to(self.custom_device)
+    def bbox_prompt_b(self, label_single_resize, device='cpu'):
+        box_single = generate_box(label_single_resize).unsqueeze(0).float().to(device)
         binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
         return box_single, binary_cube_resize
 
-    def dice_score(self, preds, labels):
+    def dice_score(self, preds, labels, device='cpu'):
         assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
         predict = preds.view(1, -1)
         target = labels.view(1, -1)
         if target.shape[1] < 1e8:
-            predict = predict.to(self.custom_device)
-            target = target.to(self.custom_device)
+            predict = predict.to(device)
+            target = target.to(device)
         predict = torch.sigmoid(predict)
         predict = torch.where(predict > 0.5, 1., 0.)
 
@@ -425,20 +418,18 @@ class SegVol(nn.Module):
                  prompt_encoder,
                  roi_size,
                  patch_size,
-                 custom_device,
                  # clip_model,
                  test_mode=False,
                  ):
         super().__init__()
-        self.custom_device = custom_device
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder(custom_device)
+        self.text_encoder = TextEncoder()
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
-        self.dice_loss = BinaryDiceLoss()
-        self.bce_loss = BCELoss()
+        self.dice_loss = BinaryDiceLoss()
+        self.bce_loss = BCELoss()
         self.decoder_iter = 6
 
     def forward(self, image, text=None, boxes=None, points=None, **kwargs):
@@ -459,12 +450,13 @@ class SegVol(nn.Module):
         return sl_loss
 
     def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
+        device = image_embedding.device
         with torch.no_grad():
             if boxes is not None:
                 if len(boxes.shape) == 2:
                     boxes = boxes[:, None, :] # (B, 1, 6)
             if text is not None:
-                text_embedding = self.text_encoder(text) # (B, 768)
+                text_embedding = self.text_encoder(text, device) # (B, 768)
             else:
                 text_embedding = None
         sparse_embeddings, dense_embeddings = self.prompt_encoder(
@@ -487,7 +479,8 @@ class SegVol(nn.Module):
         return logits
 
     def supervised_forward(self, image, image_embedding, img_shape, training_organs, train_labels):
-        iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels)
+        device = image_embedding.device
+        iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels, device)
         # select prompt
         prompt_options = [[None, iter_points, iter_organs], [iter_bboxes, None, iter_organs],
                           [None, None, iter_organs], [iter_bboxes, None, None], [None, iter_points, None],
@@ -517,7 +510,7 @@ class SegVol(nn.Module):
         # sll_loss += sll_loss_dice + sll_loss_bce
         # return sll_loss
 
-    def build_prompt_label(self, bs, training_organs, train_labels):
+    def build_prompt_label(self, bs, training_organs, train_labels, device):
         # generate prompt & label
         iter_organs = []
         iter_bboxes = []
@@ -541,10 +534,10 @@ class SegVol(nn.Module):
             iter_points_ax.append(point)
             iter_point_labels.append(point_label)
         # batched prompt
-        iter_points_ax = torch.stack(iter_points_ax, dim=0).to(self.custom_device)
-        iter_point_labels = torch.stack(iter_point_labels, dim=0).to(self.custom_device)
+        iter_points_ax = torch.stack(iter_points_ax, dim=0).to(device)
+        iter_point_labels = torch.stack(iter_point_labels, dim=0).to(device)
         iter_points = (iter_points_ax, iter_point_labels)
-        iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
+        iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(device)
         return iter_points, iter_bboxes, iter_organs
 
     # def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
@@ -611,9 +604,8 @@ class SegVol(nn.Module):
     # return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self, custom_device):
+    def __init__(self):
         super().__init__()
-        self.custom_device = custom_device
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
         self.tokenizer = None
@@ -622,20 +614,20 @@ class TextEncoder(nn.Module):
         for param in self.clip_text_model.parameters():
             param.requires_grad = False
 
-    def organ2tokens(self, organ_names):
+    def organ2tokens(self, organ_names, device):
         text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
         tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
         for key in tokens.keys():
-            tokens[key] = tokens[key].to(self.custom_device)
+            tokens[key] = tokens[key].to(device)
         return tokens
 
-    def forward(self, text):
+    def forward(self, text, device):
         if text is None:
             return None
         if type(text) is str:
             # text is supposed to be list
             text = [text]
-        tokens = self.organ2tokens(text)
+        tokens = self.organ2tokens(text, device)
         clip_outputs = self.clip_text_model(**tokens)
         text_embedding = clip_outputs.pooler_output
         text_embedding = self.dim_align(text_embedding)
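After this change, the processor's prompt builders and `dice_score` take the target device explicitly, and the model modules pick their device up from the tensors they receive. A usage sketch based only on the signatures shown in this diff (the mask shape is illustrative, and `processor` is assumed to be an already-constructed `SegVolProcessor`, e.g. `model.processor`):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# A zoomed-out binary mask at the configured spatial_size (shape illustrative).
label_single_resize = torch.zeros(32, 256, 256)
label_single_resize[10:20, 100:150, 100:150] = 1

# New-style calls: the device is passed in rather than read from custom_device.
# points_single, binary_points_resize = processor.point_prompt_b(label_single_resize, device=device)
# box_single, binary_cube_resize = processor.bbox_prompt_b(label_single_resize, device=device)
# dice = processor.dice_score(preds, labels, device=device)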