yuxin committed
Commit: c995392 · 1 parent: faa61f9

add model

Files changed:
- config.json (+0, -1)
- config_segvol.py (+0, -2)
- model_segvol_single.py (+15, -6)
config.json CHANGED
@@ -18,7 +18,6 @@
     256
   ],
   "test_mode": true,
-  "test_w_zoom": false,
   "torch_dtype": "float32",
   "transformers_version": "4.18.0"
 }
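Since the saved config no longer carries test_w_zoom, a reloaded config simply lacks the attribute. A minimal sketch of checking this, assuming a local checkpoint directory (the path is hypothetical) and the SegVolConfig class from config_segvol.py below:

from config_segvol import SegVolConfig

# Hypothetical local directory holding the updated config.json.
config = SegVolConfig.from_pretrained('./segvol_checkpoint')
print(config.test_mode)                # True
print(hasattr(config, 'test_w_zoom'))  # False after this commit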
config_segvol.py CHANGED
@@ -6,11 +6,9 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
-        test_w_zoom=False,
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
-        self.test_w_zoom = test_w_zoom
         super().__init__(**kwargs)
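A minimal sketch of constructing the config after this change; test_w_zoom is gone from the signature, so any leftover keyword now falls through **kwargs to PretrainedConfig:

from config_segvol import SegVolConfig

config = SegVolConfig(test_mode=True)
print(config.spatial_size)  # [32, 256, 256]
print(config.patch_size)    # [4, 16, 16]
print(config.test_mode)     # True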
model_segvol_single.py CHANGED
@@ -26,8 +26,16 @@ class SegVolModel(PreTrainedModel):
 
         self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
 
-    def forward(self, image,
+    def forward(self, image, zoomed_image=None, text_prompt=None, bbox_prompt=None, point_prompt=None, **kwargs):
+        print(image.shape, zoomed_image.shape, text_prompt)
+        print(bbox_prompt[0].shape, bbox_prompt[1].shape, point_prompt[0].shape, point_prompt[1].shape)
+        # test mode
+        if self.config.test_mode:
+            return
+        else:
+            print('unsupport training mode now')
+            return
+        return self.model.forward(image, text=None, boxes=None, points=None, **kwargs)
 
 # processor
 class SegVolProcessor():
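A minimal sketch of calling the new forward() signature. The tensor shapes and prompt layouts below are assumptions inferred from the print statements (forward() indexes [0] and [1] on both prompt arguments), and model stands for an already-instantiated SegVolModel; with test_mode=True the method currently just prints the shapes and returns None:

import torch

# Assumed volume layout (B, C, D, H, W) matching spatial_size = [32, 256, 256].
image = torch.randn(1, 1, 32, 256, 256)
zoomed_image = torch.randn(1, 1, 32, 256, 256)

# Assumed prompt layouts: each argument is a pair of tensors.
point_prompt = (torch.rand(1, 6, 3), torch.ones(1, 6))             # (coords, labels)
bbox_prompt = (torch.rand(1, 6), torch.zeros(1, 1, 32, 256, 256))  # (box, binary cube)

out = model(image, zoomed_image=zoomed_image, text_prompt=['liver'],
            bbox_prompt=bbox_prompt, point_prompt=point_prompt)
print(out)  # None: test mode only logs shapes for now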
@@ -89,15 +97,15 @@ class SegVolProcessor():
         item['zoom_out_label'] = item_zoom_out['label']
         return item
 
-    def
+    def point_prompt_b(self, label_single_resize):
         point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
         points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
-        binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
+        binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
         return points_single, binary_points_resize
 
-    def
+    def bbox_prompt_b(self, label_single_resize):
         box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
-        binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)
+        binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
         return box_single, binary_cube_resize
 
 class MinMaxNormalization(transforms.Transform):
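A minimal sketch of the renamed prompt helpers, assuming label_single_resize is a binary mask at the resized resolution and a CUDA device is available (both helpers call .cuda()); the trailing .unsqueeze(0).unsqueeze(0) added in this commit gives the binary volumes explicit batch and channel dimensions:

import torch

processor = SegVolProcessor(spatial_size=[32, 256, 256])

# Assumed: a sparse binary foreground mask at the resized resolution.
label_single_resize = (torch.rand(32, 256, 256) > 0.99).float()

points_single, binary_points_resize = processor.point_prompt_b(label_single_resize)
box_single, binary_cube_resize = processor.bbox_prompt_b(label_single_resize)

print(binary_points_resize.shape)  # torch.Size([1, 1, 32, 256, 256])
print(binary_cube_resize.shape)    # torch.Size([1, 1, 32, 256, 256])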
@@ -461,6 +469,7 @@ class TextEncoder(nn.Module):
         if text is None:
             return None
         if type(text) is str:
+            # text is supposed to be list
             text = [text]
         tokens = self.organ2tokens(text)
         clip_outputs = self.clip_text_model(**tokens)
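The added comment documents a small normalization step in TextEncoder: a bare string is wrapped into a one-element list before tokenization. A self-contained sketch of the same pattern:

def normalize_text(text):
    # Mirrors the TextEncoder logic: a bare str is wrapped,
    # because downstream tokenization expects a list.
    if text is None:
        return None
    if type(text) is str:
        # text is supposed to be list
        text = [text]
    return text

assert normalize_text('liver') == ['liver']
assert normalize_text(['liver', 'kidney']) == ['liver', 'kidney']
assert normalize_text(None) is None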