dillonlaird commited on
Commit
9a242d9
·
1 Parent(s): 0866d30

update for bboxes

Browse files
Files changed (2) hide show
  1. app/main.py +2 -0
  2. app/per_sam/model.py +4 -1
app/main.py CHANGED
@@ -176,6 +176,8 @@ async def get_multi_label_preds(image: str, q: MaskLabel) -> MaskBoxLabels:
176
  start = time.perf_counter()
177
  masks, bboxes, _ = per_sam_model(image_np)
178
  print(f"inference time {time.perf_counter() - start}")
 
 
179
  masks_out = []
180
  for i in range(len(masks)):
181
  mask_i = Image.fromarray(masks[i])
 
176
  start = time.perf_counter()
177
  masks, bboxes, _ = per_sam_model(image_np)
178
  print(f"inference time {time.perf_counter() - start}")
179
+ if masks is None:
180
+ return MaskBoxLabels(masks=[], bboxes=[], labels=[])
181
  masks_out = []
182
  for i in range(len(masks)):
183
  mask_i = Image.fromarray(masks[i])
app/per_sam/model.py CHANGED
@@ -101,7 +101,7 @@ def fast_inference(
101
  max_objects: int,
102
  score_thresh: float,
103
  nms_iou_thresh: float = 0.2,
104
- ) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
105
  weights_np = weights.detach().cpu().numpy()
106
  pred_masks = []
107
  pred_scores = []
@@ -190,6 +190,9 @@ def fast_inference(
190
  pred_masks.append(final_mask)
191
  pred_scores.append(score)
192
 
 
 
 
193
  pred_masks = torch.stack(pred_masks)
194
  bboxes = batched_mask_to_box(pred_masks)
195
  keep_by_nms = batched_nms(
 
101
  max_objects: int,
102
  score_thresh: float,
103
  nms_iou_thresh: float = 0.2,
104
+ ) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
105
  weights_np = weights.detach().cpu().numpy()
106
  pred_masks = []
107
  pred_scores = []
 
190
  pred_masks.append(final_mask)
191
  pred_scores.append(score)
192
 
193
+ if len(pred_masks) == 0:
194
+ return None, None, None
195
+
196
  pred_masks = torch.stack(pred_masks)
197
  bboxes = batched_mask_to_box(pred_masks)
198
  keep_by_nms = batched_nms(