yuxin committed
Commit: 6f1b94e
Parent(s): 6e933c4

add config

Files changed:
- config.json (+1 -0)
- model_segvol_single.py (+12 -8)
config.json
CHANGED
@@ -6,6 +6,7 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
+  "custom_device": "cpu",
   "model_type": "segvol",
   "patch_size": [
     4,
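For context, a minimal usage sketch (not part of the commit) showing how the new "custom_device" entry can be overridden when the model is loaded through transformers with trust_remote_code; the repository path below is a placeholder, not a name from the commit:

from transformers import AutoConfig, AutoModel

# "path/to/SegVol" is a placeholder repository id.
config = AutoConfig.from_pretrained("path/to/SegVol", trust_remote_code=True)
config.custom_device = "cuda:0"  # override the default "cpu" introduced by this commit

model = AutoModel.from_pretrained("path/to/SegVol", config=config, trust_remote_code=True)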
model_segvol_single.py
CHANGED
@@ -9,12 +9,14 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
+        custom_device='cpu',
         # clip_model='.',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
+        self.custom_device = custom_device
         # self.clip_model = clip_model
         super().__init__(**kwargs)
 
@@ -36,13 +38,14 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
+            custom_device=self.config.custom_device,
             # clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
-        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
+        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size, custom_device=self.config.custom_device)
 
-        self.custom_device =
+        self.custom_device = self.config.custom_device
 
     def forward_test(self,
         image,
@@ -125,7 +128,7 @@ class SegVolModel(PreTrainedModel):
 
 # processor
 class SegVolProcessor():
-    def __init__(self, spatial_size) -> None:
+    def __init__(self, spatial_size, custom_device) -> None:
         self.img_loader = transforms.LoadImage()
         self.transform4test = transforms.Compose(
             [
@@ -137,7 +140,7 @@ class SegVolProcessor():
             ]
         )
         self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
-        self.custom_device =
+        self.custom_device = custom_device
         self.transform4train = transforms.Compose(
             [
                 # transforms.AddChanneld(keys=["image"]),
@@ -422,15 +425,16 @@ class SegVol(nn.Module):
         prompt_encoder,
         roi_size,
         patch_size,
+        custom_device,
         # clip_model,
         test_mode=False,
     ):
         super().__init__()
-        self.custom_device =
+        self.custom_device = custom_device
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder()
+        self.text_encoder = TextEncoder(custom_device=custom_device)
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
         self.dice_loss = BinaryDiceLoss().to(self.custom_device)
@@ -607,9 +611,9 @@ class SegVol(nn.Module):
     # return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, custom_device):
         super().__init__()
-        self.custom_device =
+        self.custom_device = custom_device
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
         self.tokenizer = None
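The Python-side changes all follow the same pattern: the device string lives on the config and is passed explicitly into every nested constructor instead of being hard-coded. A self-contained sketch of that pattern, using simplified stand-in classes rather than the actual SegVol modules:

import torch
import torch.nn as nn

class TextEncoderSketch(nn.Module):
    # Innermost module: the device arrives as a constructor argument
    # instead of being hard-coded inside the class.
    def __init__(self, custom_device):
        super().__init__()
        self.custom_device = custom_device
        self.proj = nn.Linear(8, 8).to(custom_device)

class SegVolSketch(nn.Module):
    # Outer module: forwards the same device string to every child it builds,
    # mirroring SegVol -> TextEncoder in the diff above.
    def __init__(self, custom_device='cpu'):
        super().__init__()
        self.custom_device = custom_device
        self.text_encoder = TextEncoderSketch(custom_device=custom_device)
        # Tensors that need an explicit device are moved once, at construction time.
        self.loss_weight = torch.ones(1).to(self.custom_device)

model = SegVolSketch(custom_device='cpu')
print(model.text_encoder.custom_device)  # 'cpu', propagated from the top-level argument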