from typing import List, Tuple

import modal
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import SamHQModel, SamHQProcessor

from .app import app
from .image import image


def get_detections_from_segment_anything(bounding_boxes, list_of_masks, iou_scores):
    """Bundle box prompts, masks and IoU scores into an ``sv.Detections``.

    Each box is assigned its positional index as ``class_id``.

    Args:
        bounding_boxes: Sequence of boxes, passed to ``sv.Detections`` as
            ``xyxy``.
        list_of_masks: One mask per box (presumably boolean ``(H, W)``
            arrays, as returned by ``SegmentAnythingModalApp.forward`` --
            TODO confirm).
        iou_scores: One score per box, stored as the detection confidence.

    Returns:
        An ``sv.Detections`` instance; an empty one when there are no boxes.
    """
    num_boxes = len(bounding_boxes)
    # ``len`` rather than truthiness so that NumPy arrays are accepted too.
    if num_boxes == 0:
        # Arrays built from empty inputs do not have the (n, 4) / (n, H, W)
        # shapes that sv.Detections expects, so use its dedicated constructor.
        return sv.Detections.empty()

    return sv.Detections(
        xyxy=np.array(bounding_boxes),
        mask=np.array(list_of_masks),
        class_id=np.arange(num_boxes),
        confidence=np.array(iou_scores),
    )


@app.cls(gpu="T4", image=image)
class SegmentAnythingModalApp:
    """Modal class that runs SAM-HQ box-prompted segmentation on a GPU."""

    @modal.enter()
    def setup(self):
        """Load the SAM-HQ model and processor and move the model to CUDA."""
        # Smaller alternative checkpoint: "syscv-community/sam-hq-vit-base"
        model_name = "syscv-community/sam-hq-vit-huge"
        self.model = SamHQModel.from_pretrained(model_name)
        self.model.to("cuda")
        self.model.eval()
        self.processor: SamHQProcessor = SamHQProcessor.from_pretrained(model_name)

    @modal.method()
    def forward(self, image, bounding_boxes: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Segment ``image`` using one box prompt per entry of ``bounding_boxes``.

        Args:
            image: Input image handed to the processor (the script below
                passes a PIL image).
            bounding_boxes: Box prompts for this single image.

        Returns:
            ``(masks, iou_scores)`` as NumPy arrays, indexed by box along
            their first axis.
        """
        # The processor takes boxes batched per image; wrap the boxes of our
        # single image in an outer list.
        batched_boxes = [bounding_boxes]
        inputs = self.processor(image, input_boxes=batched_boxes, return_tensors="pt")
        inputs = inputs.to("cuda")

        with torch.no_grad():
            outputs = self.model(**inputs)

        masks_per_image = self.processor.post_process_masks(
            outputs.pred_masks,
            inputs["original_sizes"],
            inputs["reshaped_input_sizes"],
        )
        # First (only) image in the batch, then index 0 on the axis after the
        # box axis (presumably the first mask candidate per box -- TODO confirm).
        masks = masks_per_image[0][:, 0]
        iou_scores = outputs.iou_scores[0, :, 0]

        return masks.cpu().numpy(), iou_scores.cpu().numpy()


if __name__ == "__main__":
    # Manual smoke test: look up the deployed class, segment a sample image
    # with hard-coded box prompts and save the annotated result.
    segment_anything = modal.Cls.from_name(app.name, SegmentAnythingModalApp.__name__)()

    # Named ``source_image`` so the module-level ``image`` (imported from
    # ``.image``) is not rebound.
    source_image = Image.open("images/image.png")
    bounding_boxes = [
        [2449, 2021, 2758, 2359],
        [436, 1942, 1002, 2193],
        [2785, 1945, 3259, 2374],
        [3285, 1996, 3721, 2405],
        [1968, 2035, 2451, 2474],
        [1741, 1909, 2098, 2320],
    ]

    masks, iou_scores = segment_anything.forward.remote(image=source_image, bounding_boxes=bounding_boxes)
    detections = get_detections_from_segment_anything(bounding_boxes, masks, iou_scores)

    mask_annotator = sv.MaskAnnotator()
    annotated_image = mask_annotator.annotate(scene=source_image, detections=detections)
    annotated_image.save("images/resized_image_with_detections_annotated.png")