yuxin
commited on
Commit
•
ad9a9f8
1
Parent(s):
4ee7532
add model
Browse files- model_segvol_single.py +4 -6
model_segvol_single.py
CHANGED
@@ -52,9 +52,9 @@ class SegVolModel(PreTrainedModel):
|
|
52 |
bbox_prompt, bbox_prompt_map = bbox_prompt
|
53 |
if point_prompt is not None:
|
54 |
point_prompt, point_prompt_map = point_prompt
|
55 |
-
print(image.shape, zoomed_image.shape, text_prompt)
|
56 |
assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
|
57 |
-
print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
|
58 |
volume_shape = image[0][0].shape
|
59 |
|
60 |
with torch.no_grad():
|
@@ -62,14 +62,12 @@ class SegVolModel(PreTrainedModel):
|
|
62 |
text=text_prompt,
|
63 |
boxes=bbox_prompt,
|
64 |
points=point_prompt)
|
65 |
-
print(logits_global_single.shape)
|
66 |
logits_global_single = F.interpolate(
|
67 |
logits_global_single.cpu(),
|
68 |
size=volume_shape, mode='nearest')
|
69 |
if not use_zoom:
|
70 |
return logits_global_single
|
71 |
-
|
72 |
-
print(torch.unique(logits_global_single))
|
73 |
if point_prompt_map is not None:
|
74 |
binary_points = F.interpolate(
|
75 |
point_prompt_map.float(),
|
@@ -88,7 +86,7 @@ class SegVolModel(PreTrainedModel):
|
|
88 |
image_single_cropped = image[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
89 |
global_preds = (torch.sigmoid(logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
|
90 |
|
91 |
-
assert not (bbox_prompt is not None and point_prompt is not None)
|
92 |
prompt_reflection = None
|
93 |
if bbox_prompt is not None:
|
94 |
binary_cube_cropped = binary_cube[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
|
|
52 |
bbox_prompt, bbox_prompt_map = bbox_prompt
|
53 |
if point_prompt is not None:
|
54 |
point_prompt, point_prompt_map = point_prompt
|
55 |
+
# print(image.shape, zoomed_image.shape, text_prompt)
|
56 |
assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
|
57 |
+
# print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
|
58 |
volume_shape = image[0][0].shape
|
59 |
|
60 |
with torch.no_grad():
|
|
|
62 |
text=text_prompt,
|
63 |
boxes=bbox_prompt,
|
64 |
points=point_prompt)
|
|
|
65 |
logits_global_single = F.interpolate(
|
66 |
logits_global_single.cpu(),
|
67 |
size=volume_shape, mode='nearest')
|
68 |
if not use_zoom:
|
69 |
return logits_global_single
|
70 |
+
|
|
|
71 |
if point_prompt_map is not None:
|
72 |
binary_points = F.interpolate(
|
73 |
point_prompt_map.float(),
|
|
|
86 |
image_single_cropped = image[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
87 |
global_preds = (torch.sigmoid(logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
|
88 |
|
89 |
+
assert not (bbox_prompt is not None and point_prompt is not None), 'Do not use point prompt and box prompt at the same time.'
|
90 |
prompt_reflection = None
|
91 |
if bbox_prompt is not None:
|
92 |
binary_cube_cropped = binary_cube[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|