yuxin commited on
Commit
ad9a9f8
1 Parent(s): 4ee7532
Files changed (1) hide show
  1. model_segvol_single.py +4 -6
model_segvol_single.py CHANGED
@@ -52,9 +52,9 @@ class SegVolModel(PreTrainedModel):
52
  bbox_prompt, bbox_prompt_map = bbox_prompt
53
  if point_prompt is not None:
54
  point_prompt, point_prompt_map = point_prompt
55
- print(image.shape, zoomed_image.shape, text_prompt)
56
  assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
57
- print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
58
  volume_shape = image[0][0].shape
59
 
60
  with torch.no_grad():
@@ -62,14 +62,12 @@ class SegVolModel(PreTrainedModel):
62
  text=text_prompt,
63
  boxes=bbox_prompt,
64
  points=point_prompt)
65
- print(logits_global_single.shape)
66
  logits_global_single = F.interpolate(
67
  logits_global_single.cpu(),
68
  size=volume_shape, mode='nearest')
69
  if not use_zoom:
70
  return logits_global_single
71
- print(logits_global_single.shape)
72
- print(torch.unique(logits_global_single))
73
  if point_prompt_map is not None:
74
  binary_points = F.interpolate(
75
  point_prompt_map.float(),
@@ -88,7 +86,7 @@ class SegVolModel(PreTrainedModel):
88
  image_single_cropped = image[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
89
  global_preds = (torch.sigmoid(logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
90
 
91
- assert not (bbox_prompt is not None and point_prompt is not None)
92
  prompt_reflection = None
93
  if bbox_prompt is not None:
94
  binary_cube_cropped = binary_cube[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
 
52
  bbox_prompt, bbox_prompt_map = bbox_prompt
53
  if point_prompt is not None:
54
  point_prompt, point_prompt_map = point_prompt
55
+ # print(image.shape, zoomed_image.shape, text_prompt)
56
  assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
57
+ # print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
58
  volume_shape = image[0][0].shape
59
 
60
  with torch.no_grad():
 
62
  text=text_prompt,
63
  boxes=bbox_prompt,
64
  points=point_prompt)
 
65
  logits_global_single = F.interpolate(
66
  logits_global_single.cpu(),
67
  size=volume_shape, mode='nearest')
68
  if not use_zoom:
69
  return logits_global_single
70
+
 
71
  if point_prompt_map is not None:
72
  binary_points = F.interpolate(
73
  point_prompt_map.float(),
 
86
  image_single_cropped = image[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
87
  global_preds = (torch.sigmoid(logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
88
 
89
+ assert not (bbox_prompt is not None and point_prompt is not None), 'Do not use point prompt and box prompt at the same time.'
90
  prompt_reflection = None
91
  if bbox_prompt is not None:
92
  binary_cube_cropped = binary_cube[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]