Martin Tomov commited on
Commit
74d1fdf
·
verified ·
1 Parent(s): 212f0f5

spaces.GPU

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -92,19 +92,21 @@ def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> L
92
  masks[idx] = cv2.fillPoly(np.zeros(shape, dtype=np.uint8), [polygon], 1)
93
  return list(masks)
94
 
 
95
  def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[DetectionResult]:
96
  detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
97
- object_detector = pipeline(model=detector_id, task="zero-shot-object-detection")
98
  labels = [label if label.endswith(".") else label + "." for label in labels]
99
  results = object_detector(image, candidate_labels=labels, threshold=threshold)
100
  return [DetectionResult.from_dict(result) for result in results]
101
 
 
102
  def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
103
  segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
104
- segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id)
105
  processor = AutoProcessor.from_pretrained(segmenter_id)
106
  boxes = get_boxes(detection_results)
107
- inputs = processor(images=image, input_boxes=boxes, return_tensors="pt")
108
  outputs = segmentator(**inputs)
109
  masks = processor.post_process_masks(masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
110
  masks = refine_masks(masks, polygon_refinement)
 
92
  masks[idx] = cv2.fillPoly(np.zeros(shape, dtype=np.uint8), [polygon], 1)
93
  return list(masks)
94
 
95
+ @spaces.GPU
96
  def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[DetectionResult]:
97
  detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
98
+ object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=0)
99
  labels = [label if label.endswith(".") else label + "." for label in labels]
100
  results = object_detector(image, candidate_labels=labels, threshold=threshold)
101
  return [DetectionResult.from_dict(result) for result in results]
102
 
103
+ @spaces.GPU
104
  def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
105
  segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
106
+ segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to("cuda")
107
  processor = AutoProcessor.from_pretrained(segmenter_id)
108
  boxes = get_boxes(detection_results)
109
+ inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to("cuda")
110
  outputs = segmentator(**inputs)
111
  masks = processor.post_process_masks(masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
112
  masks = refine_masks(masks, polygon_refinement)