yuxin committed
Commit 6e933c4
1 Parent(s): da33dfe

add config

config.json CHANGED
@@ -6,7 +6,6 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
-  "clip_model": "openai/clip-vit-base-patch32",
   "model_type": "segvol",
   "patch_size": [
     4,
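
With clip_model removed from config.json, the configuration no longer points at any external checkpoint; the model is resolved entirely through the auto_map entries shown above. A minimal loading sketch, assuming a placeholder repo id that is not part of this commit:

from transformers import AutoModel

# Placeholder repo id; substitute the actual SegVol repository.
model = AutoModel.from_pretrained(
    "your-namespace/segvol",
    trust_remote_code=True,  # needed so auto_map resolves model_segvol_single.SegVolModel
)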
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model_segvol_single.py CHANGED
@@ -9,13 +9,13 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
-        clip_model='openai/clip-vit-base-patch32',
+        # clip_model='.',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
-        self.clip_model = clip_model
+        # self.clip_model = clip_model
         super().__init__(**kwargs)
 
 class SegVolModel(PreTrainedModel):
@@ -36,7 +36,7 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
-            clip_model=self.config.clip_model,
+            # clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
@@ -118,7 +118,6 @@ class SegVolModel(PreTrainedModel):
         return logits_global_single
 
     def forward_train(self, image, train_organs, train_labels):
-        print('in forward_train')
         loss = self.model(image, text=None, boxes=None, points=None,
                           train_organs=train_organs,
                           train_labels=train_labels)
@@ -318,7 +317,6 @@ def generate_box(pred_pre, bbox_shift=None):
     ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
     if all(tensor.nelement() == 0 for tensor in ones_idx):
         bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
-        # print(bboxes, bboxes.shape)
         return bboxes
     min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
     max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]
@@ -395,8 +393,6 @@ def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_p
         extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
         points = torch.cat((points, extra_negative_points), dim=0)
         labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
-        # print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
-        # print('==> points ', points.shape, labels)
 
     if fix_extra_point_num is None:
         left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
@@ -415,7 +411,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig
+from transformers import CLIPTextModel, CLIPTextConfig
 import random
 
 #%% set up model
@@ -426,7 +422,7 @@ class SegVol(nn.Module):
                  prompt_encoder,
                  roi_size,
                  patch_size,
-                 clip_model,
+                 # clip_model,
                  test_mode=False,
                  ):
         super().__init__()
@@ -434,7 +430,7 @@ class SegVol(nn.Module):
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder(clip_model)
+        self.text_encoder = TextEncoder()
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
         self.dice_loss = BinaryDiceLoss().to(self.custom_device)
@@ -453,7 +449,6 @@ class SegVol(nn.Module):
 
         # train mode
         ## sl
-        print('supervised_forward ready')
         sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
         ## ssl
         # ssl_loss = self.unsupervised_forward(image, image_embedding, kwargs['pseudo_seg_cleaned'], img_shape)
@@ -612,12 +607,12 @@ class SegVol(nn.Module):
         # return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self, clip_model):
+    def __init__(self):
         super().__init__()
         self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
-        self.tokenizer = AutoTokenizer.from_pretrained(clip_model)
+        self.tokenizer = None
         self.dim_align = nn.Linear(512, 768)
         # freeze text encoder
         for param in self.clip_text_model.parameters():
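
TextEncoder.__init__ now leaves self.tokenizer as None and the hard-coded openai/clip-vit-base-patch32 reference is gone, while this commit adds the CLIP tokenizer files (vocab.json, merges.txt, tokenizer.json, tokenizer_config.json, special_tokens_map.json) to the repository itself. A hedged sketch of how the bundled tokenizer could be attached after loading, assuming SegVolModel keeps the SegVol module as self.model (as forward_train above suggests) and that repo_dir is the local path of this repository:

from transformers import CLIPTokenizer

repo_dir = "."  # hypothetical: local path of this repo, which now ships the tokenizer files
tokenizer = CLIPTokenizer.from_pretrained(repo_dir)

# model is a loaded SegVolModel instance; attach the bundled tokenizer
# to the text encoder whose tokenizer __init__ now sets to None.
model.model.text_encoder.tokenizer = tokenizer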
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+{"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip", "special_tokens_map_file": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip/special_tokens_map.json", "tokenizer_class": "CLIPTokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff