yuxin committed
Commit 6f1b94e
Parent: 6e933c4

add config

Files changed (2)
  1. config.json +1 -0
  2. model_segvol_single.py +12 -8
config.json CHANGED
@@ -6,6 +6,7 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
+  "custom_device": "cpu",
   "model_type": "segvol",
   "patch_size": [
     4,
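
The new "custom_device" entry defaults to "cpu" and is read by SegVolConfig in the Python diff below. As a minimal sketch of how the key surfaces after loading, assuming a local checkpoint directory (the path is illustrative, not part of this commit):

    from transformers import AutoConfig

    # Hypothetical local directory containing the updated config.json.
    config = AutoConfig.from_pretrained("./segvol", trust_remote_code=True)
    print(config.custom_device)  # "cpu", the default added by this commit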
model_segvol_single.py CHANGED
@@ -9,12 +9,14 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
+        custom_device='cpu',
         # clip_model='.',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
+        self.custom_device = custom_device
         # self.clip_model = clip_model
         super().__init__(**kwargs)
 
@@ -36,13 +38,14 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
+            custom_device=self.config.custom_device,
             # clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
-        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
+        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size, custom_device=self.config.custom_device)
 
-        self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.custom_device = self.config.custom_device
 
     def forward_test(self,
                      image,
@@ -125,7 +128,7 @@ class SegVolModel(PreTrainedModel):
 
 # processor
 class SegVolProcessor():
-    def __init__(self, spatial_size) -> None:
+    def __init__(self, spatial_size, custom_device) -> None:
         self.img_loader = transforms.LoadImage()
         self.transform4test = transforms.Compose(
             [
@@ -137,7 +140,7 @@ class SegVolProcessor():
             ]
         )
         self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
-        self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.custom_device = custom_device
         self.transform4train = transforms.Compose(
             [
                 # transforms.AddChanneld(keys=["image"]),
@@ -422,15 +425,16 @@ class SegVol(nn.Module):
                  prompt_encoder,
                  roi_size,
                  patch_size,
+                 custom_device,
                  # clip_model,
                  test_mode=False,
                  ):
         super().__init__()
-        self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.custom_device = custom_device
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder()
+        self.text_encoder = TextEncoder(custom_device=custom_device)
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
         self.dice_loss = BinaryDiceLoss().to(self.custom_device)
@@ -607,9 +611,9 @@ class SegVol(nn.Module):
         # return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, custom_device):
         super().__init__()
-        self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.custom_device = custom_device
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
         self.tokenizer = None
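
Overall, the commit replaces every hard-coded torch.device('cuda' if torch.cuda.is_available() else 'cpu') with a single custom_device value threaded from SegVolConfig through SegVolModel, SegVolProcessor, SegVol, and TextEncoder. A minimal sketch of overriding the device at load time, again with an illustrative checkpoint path:

    import torch
    from transformers import AutoConfig, AutoModel

    # Hypothetical checkpoint directory; pick the device explicitly instead of
    # relying on the old construction-time CUDA probe.
    config = AutoConfig.from_pretrained("./segvol", trust_remote_code=True)
    config.custom_device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModel.from_pretrained("./segvol", config=config, trust_remote_code=True)
    # The processor and text encoder now read config.custom_device rather than
    # probing torch.cuda.is_available() in their constructors.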