yuxin committed
Commit: d0601bd
Parent(s): 6f1b94e

add config

Files changed (2)
  1. config.json +0 -1
  2. model_segvol_single.py +28 -36
config.json CHANGED
@@ -6,7 +6,6 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
-  "custom_device": "cpu",
   "model_type": "segvol",
   "patch_size": [
     4,
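
With the "custom_device" key gone from config.json, the checkpoint no longer pins a device at load time; placement follows normal PyTorch semantics. A minimal loading sketch (the local path './SegVol' is a placeholder, not part of this commit):

    import torch
    from transformers import AutoModel

    # trust_remote_code pulls in model_segvol_single.py from the repo
    model = AutoModel.from_pretrained('./SegVol', trust_remote_code=True)
    # device placement is now an ordinary .to() call instead of a config field
    model = model.to('cuda' if torch.cuda.is_available() else 'cpu')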
model_segvol_single.py CHANGED
@@ -9,15 +9,11 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
-        custom_device='cpu',
-        # clip_model='.',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
-        self.custom_device = custom_device
-        # self.clip_model = clip_model
         super().__init__(**kwargs)
 
 class SegVolModel(PreTrainedModel):
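
After this hunk, SegVolConfig carries only model-shape settings plus the test flag. A quick sketch of the slimmed constructor (assuming model_segvol_single is importable from the repo root):

    from model_segvol_single import SegVolConfig

    config = SegVolConfig(test_mode=True)          # no custom_device argument anymore
    print(config.spatial_size, config.patch_size)  # [32, 256, 256] [4, 16, 16]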
@@ -38,14 +34,11 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
-            custom_device=self.config.custom_device,
             # clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
-        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size, custom_device=self.config.custom_device)
-
-        self.custom_device = self.config.custom_device
+        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
 
     def forward_test(self,
         image,
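
The model no longer caches self.custom_device, so code that previously read that attribute can recover the current placement from the parameters instead; a one-line sketch:

    device = next(model.parameters()).device  # wherever .to(...) last moved the model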
@@ -53,7 +46,8 @@ class SegVolModel(PreTrainedModel):
         text_prompt=None,
         bbox_prompt_group=None,
         point_prompt_group=None,
-        use_zoom=True):
+        use_zoom=True,):
+        device = image.device
         assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
         assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt'
         bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
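
forward_test now infers the device from the input tensor, so inference runs wherever the caller puts the data. A usage sketch (tensor names and the 'liver' prompt are illustrative):

    image = image.to('cuda')                # preprocessed (1, 1, D, H, W) volume
    zoomed_image = zoomed_image.to('cuda')  # its zoomed-out counterpart
    logits = model.forward_test(image, zoomed_image, text_prompt=['liver'])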
@@ -110,7 +104,7 @@ class SegVolModel(PreTrainedModel):
             ## inference
             with torch.no_grad():
                 logits_single_cropped = sliding_window_inference(
-                    image_single_cropped.to(self.custom_device), prompt_reflection,
+                    image_single_cropped.to(device), prompt_reflection,
                     self.config.spatial_size, 1, self.model, 0.5,
                     text=text_prompt,
                     use_box=bbox_prompt is not None,
@@ -128,7 +122,7 @@ class SegVolModel(PreTrainedModel):
 
 # processor
 class SegVolProcessor():
-    def __init__(self, spatial_size, custom_device) -> None:
+    def __init__(self, spatial_size) -> None:
         self.img_loader = transforms.LoadImage()
         self.transform4test = transforms.Compose(
             [
@@ -140,7 +134,6 @@ class SegVolProcessor():
             ]
         )
         self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
-        self.custom_device = custom_device
         self.transform4train = transforms.Compose(
             [
                 # transforms.AddChanneld(keys=["image"]),
@@ -217,24 +210,24 @@ class SegVolProcessor():
         item['zoom_out_label'] = item_zoom_out['label']
         return item
 
-    def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0):
+    def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0, device='cpu'):
         point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
-        points_single = (point.unsqueeze(0).float().to(self.custom_device), point_label.unsqueeze(0).float().to(self.custom_device))
+        points_single = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
         binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
         return points_single, binary_points_resize
 
-    def bbox_prompt_b(self, label_single_resize):
-        box_single = generate_box(label_single_resize).unsqueeze(0).float().to(self.custom_device)
+    def bbox_prompt_b(self, label_single_resize, device='cpu'):
+        box_single = generate_box(label_single_resize).unsqueeze(0).float().to(device)
         binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
         return box_single, binary_cube_resize
 
-    def dice_score(self, preds, labels):
+    def dice_score(self, preds, labels, device='cpu'):
         assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
         predict = preds.view(1, -1)
         target = labels.view(1, -1)
         if target.shape[1] < 1e8:
-            predict = predict.to(self.custom_device)
-            target = target.to(self.custom_device)
+            predict = predict.to(device)
+            target = target.to(device)
         predict = torch.sigmoid(predict)
         predict = torch.where(predict > 0.5, 1., 0.)
 
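The processor's prompt builders and dice_score now take an explicit device keyword, defaulting to 'cpu', instead of reading stored state; for example:

    points_single, binary_points_resize = processor.point_prompt_b(label_single_resize, device=device)
    box_single, binary_cube_resize = processor.bbox_prompt_b(label_single_resize, device=device)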
@@ -425,20 +418,18 @@ class SegVol(nn.Module):
                  prompt_encoder,
                  roi_size,
                  patch_size,
-                 custom_device,
                  # clip_model,
                  test_mode=False,
                  ):
         super().__init__()
-        self.custom_device = custom_device
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder(custom_device=custom_device)
+        self.text_encoder = TextEncoder()
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
-        self.dice_loss = BinaryDiceLoss().to(self.custom_device)
-        self.bce_loss = BCELoss().to(self.custom_device)
+        self.dice_loss = BinaryDiceLoss()
+        self.bce_loss = BCELoss()
         self.decoder_iter = 6
 
     def forward(self, image, text=None, boxes=None, points=None, **kwargs):
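
Dropping the eager .to(custom_device) calls is safe as long as TextEncoder, BinaryDiceLoss, and BCELoss subclass nn.Module (torch.nn.BCELoss does): attributes of an nn.Module that are themselves modules are registered as submodules, so one move on the parent carries them along:

    seg_vol = seg_vol.to('cuda')  # encoders, decoder, and loss modules all move together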
@@ -459,12 +450,13 @@ class SegVol(nn.Module):
         return sl_loss
 
     def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
+        device = image_embedding.device
         with torch.no_grad():
             if boxes is not None:
                 if len(boxes.shape) == 2:
                     boxes = boxes[:, None, :] # (B, 1, 6)
             if text is not None:
-                text_embedding = self.text_encoder(text) # (B, 768)
+                text_embedding = self.text_encoder(text, device) # (B, 768)
             else:
                 text_embedding = None
         sparse_embeddings, dense_embeddings = self.prompt_encoder(
@@ -487,7 +479,8 @@ class SegVol(nn.Module):
         return logits
 
     def supervised_forward(self, image, image_embedding, img_shape, training_organs, train_labels):
-        iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels)
+        device = image_embedding.device
+        iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels, device)
         # select prompt
         prompt_options = [[None, iter_points, iter_organs], [iter_bboxes, None, iter_organs],
                           [None, None, iter_organs], [iter_bboxes, None, None], [None, iter_points, None],
@@ -517,7 +510,7 @@ class SegVol(nn.Module):
         # sll_loss += sll_loss_dice + sll_loss_bce
         # return sll_loss
 
-    def build_prompt_label(self, bs, training_organs, train_labels):
+    def build_prompt_label(self, bs, training_organs, train_labels, device):
         # generate prompt & label
         iter_organs = []
         iter_bboxes = []
@@ -541,10 +534,10 @@ class SegVol(nn.Module):
             iter_points_ax.append(point)
             iter_point_labels.append(point_label)
         # batched prompt
-        iter_points_ax = torch.stack(iter_points_ax, dim=0).to(self.custom_device)
-        iter_point_labels = torch.stack(iter_point_labels, dim=0).to(self.custom_device)
+        iter_points_ax = torch.stack(iter_points_ax, dim=0).to(device)
+        iter_point_labels = torch.stack(iter_point_labels, dim=0).to(device)
         iter_points = (iter_points_ax, iter_point_labels)
-        iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
+        iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(device)
         return iter_points, iter_bboxes, iter_organs
 
     # def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
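
On the training path the target device is threaded from the image embedding down into build_prompt_label, so prompt tensors are created alongside the features they condition; schematically:

    device = image_embedding.device
    iter_points, iter_bboxes, iter_organs = seg_vol.build_prompt_label(
        image.shape[0], training_organs, train_labels, device)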
@@ -611,9 +604,8 @@ class SegVol(nn.Module):
     # return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self, custom_device):
+    def __init__(self):
         super().__init__()
-        self.custom_device = custom_device
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
         self.tokenizer = None
@@ -622,20 +614,20 @@ class TextEncoder(nn.Module):
         for param in self.clip_text_model.parameters():
             param.requires_grad = False
 
-    def organ2tokens(self, organ_names):
+    def organ2tokens(self, organ_names, device):
         text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
         tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
         for key in tokens.keys():
-            tokens[key] = tokens[key].to(self.custom_device)
+            tokens[key] = tokens[key].to(device)
         return tokens
 
-    def forward(self, text):
+    def forward(self, text, device):
         if text is None:
             return None
         if type(text) is str:
             # text is supposed to be list
             text = [text]
-        tokens = self.organ2tokens(text)
+        tokens = self.organ2tokens(text, device)
         clip_outputs = self.clip_text_model(**tokens)
         text_embedding = clip_outputs.pooler_output
         text_embedding = self.dim_align(text_embedding)
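
TextEncoder now receives the target device per call instead of owning one. Since the class initializes self.tokenizer = None, a tokenizer must still be attached before use; a hypothetical setup with a standard CLIP tokenizer (the checkpoint choice and 'liver' prompt are assumptions, not part of this commit):

    from transformers import CLIPTokenizer

    text_encoder.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
    text_embedding = text_encoder(['liver'], device)  # tokens are moved to device before the CLIP forward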
 