Martin Tomov commited on
Commit
ca021c4
Β·
verified Β·
1 Parent(s): 4c2297f

cropped bbox output

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -109,7 +109,7 @@ def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detect
109
  results = object_detector(image, candidate_labels=labels, threshold=threshold)
110
  return [DetectionResult.from_dict(result) for result in results]
111
 
112
- @spaces.GPU
113
  def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
114
  segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
115
  segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to("cuda")
@@ -190,19 +190,33 @@ def detections_to_json(detections):
190
  detections_list.append(detection_dict)
191
  return detections_list
192
 
 
 
 
 
 
 
 
 
193
  def process_image(image, include_json, include_bboxes):
194
  labels = ["insect"]
195
  original_image, detections = grounded_segmentation(image, labels, threshold=0.3, polygon_refinement=True)
196
  yellow_background_with_insects = create_yellow_background_with_insects(np.array(original_image), detections)
197
  annotated_image = plot_detections(yellow_background_with_insects, detections, include_bboxes)
 
 
 
 
 
 
198
  if include_json:
199
  detections_json = detections_to_json(detections)
200
  json_output_path = "insect_detections.json"
201
  with open(json_output_path, 'w') as json_file:
202
  json.dump(detections_json, json_file, indent=4)
203
- return annotated_image, json.dumps(detections_json, separators=(',', ':'))
204
- else:
205
- return annotated_image, None
206
 
207
  examples = [
208
  ["flower-night.jpg"]
@@ -211,7 +225,7 @@ examples = [
211
  gr.Interface(
212
  fn=process_image,
213
  inputs=[gr.Image(type="pil"), gr.Checkbox(label="Include JSON", value=False), gr.Checkbox(label="Include Bounding Boxes", value=False)],
214
- outputs=[gr.Image(type="numpy"), gr.Textbox()],
215
  title="InsectSAM 🐞",
216
  examples=examples
217
  ).launch()
 
109
  results = object_detector(image, candidate_labels=labels, threshold=threshold)
110
  return [DetectionResult.from_dict(result) for result in results]
111
 
112
+ @spaces.GGPU
113
  def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
114
  segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
115
  segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to("cuda")
 
190
  detections_list.append(detection_dict)
191
  return detections_list
192
 
193
+ def crop_bounding_boxes(image: np.ndarray, detections: List[DetectionResult]) -> List[np.ndarray]:
194
+ crops = []
195
+ for detection in detections:
196
+ xmin, ymin, xmax, ymax = detection.box.xyxy
197
+ crop = image[ymin:ymax, xmin:xmax]
198
+ crops.append(crop)
199
+ return crops
200
+
201
  def process_image(image, include_json, include_bboxes):
202
  labels = ["insect"]
203
  original_image, detections = grounded_segmentation(image, labels, threshold=0.3, polygon_refinement=True)
204
  yellow_background_with_insects = create_yellow_background_with_insects(np.array(original_image), detections)
205
  annotated_image = plot_detections(yellow_background_with_insects, detections, include_bboxes)
206
+
207
+ results = [annotated_image]
208
+ if include_bboxes:
209
+ crops = crop_bounding_boxes(np.array(original_image), detections)
210
+ results.extend(crops)
211
+
212
  if include_json:
213
  detections_json = detections_to_json(detections)
214
  json_output_path = "insect_detections.json"
215
  with open(json_output_path, 'w') as json_file:
216
  json.dump(detections_json, json_file, indent=4)
217
+ results.append(json.dumps(detections_json, separators=(',', ':')))
218
+
219
+ return tuple(results)
220
 
221
  examples = [
222
  ["flower-night.jpg"]
 
225
  gr.Interface(
226
  fn=process_image,
227
  inputs=[gr.Image(type="pil"), gr.Checkbox(label="Include JSON", value=False), gr.Checkbox(label="Include Bounding Boxes", value=False)],
228
+ outputs=[gr.Image(type="numpy")] + [gr.Image(type="numpy")] * 10 + [gr.Textbox()],
229
  title="InsectSAM 🐞",
230
  examples=examples
231
  ).launch()