File size: 2,663 Bytes
518d841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import List, Tuple

import modal
import numpy as np
import supervision as sv
import torch
from PIL import Image

from transformers import SamHQModel, SamHQProcessor

from .app import app
from .image import image


def get_detections_from_segment_anything(bounding_boxes, list_of_masks, iou_scores):
    """Package segmentation results as a ``supervision`` ``Detections`` object.

    Args:
        bounding_boxes: Boxes used as ``xyxy``; one entry per detection.
        list_of_masks: Masks aligned with ``bounding_boxes``.
        iou_scores: Per-box scores, stored as the detection ``confidence``.

    Returns:
        ``sv.Detections`` whose ``class_id`` is simply each box's position
        (0, 1, 2, ...) so every box gets its own id/colour.
    """
    box_count = len(bounding_boxes)
    # Position index doubles as the class id.
    positional_ids = [idx for idx in range(box_count)]
    return sv.Detections(
        xyxy=np.array(bounding_boxes),
        mask=np.array(list_of_masks),
        class_id=np.array(positional_ids),
        confidence=np.array(iou_scores),
    )


@app.cls(gpu="T4", image=image)
class SegmentAnythingModalApp:
    """Modal GPU class that runs SAM-HQ segmentation prompted by bounding boxes.

    ``setup`` loads the model and processor onto CUDA; ``forward`` reuses them
    for each request.
    """

    @modal.enter()
    def setup(self):
        """Load the SAM-HQ checkpoint and its processor (Modal ``enter`` hook)."""
        # A lighter alternative checkpoint is "syscv-community/sam-hq-vit-base".
        checkpoint = "syscv-community/sam-hq-vit-huge"
        self.model = SamHQModel.from_pretrained(checkpoint)
        self.model.to("cuda").eval()
        self.processor: SamHQProcessor = SamHQProcessor.from_pretrained(checkpoint)

    @modal.method()
    def forward(self, image, bounding_boxes: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Segment ``image`` with one SAM-HQ box prompt per entry of ``bounding_boxes``.

        Args:
            image: Image handed directly to ``SamHQProcessor`` (the ``__main__``
                example passes a PIL image).
            bounding_boxes: Box prompts for a single image; presumably
                ``[x0, y0, x1, y1]`` pixel coordinates — confirm with callers.

        Returns:
            ``(masks, iou_scores)`` as NumPy arrays, keeping only index 0 of the
            per-box candidate axis for both masks and scores.
        """
        # The processor takes boxes grouped per image, so wrap the single image's boxes.
        model_inputs = self.processor(image, input_boxes=[bounding_boxes], return_tensors="pt")
        model_inputs = model_inputs.to("cuda")

        with torch.no_grad():
            predictions = self.model(**model_inputs)

        # Resize the low-resolution predictions back to the original image size.
        per_image_masks = self.processor.post_process_masks(
            predictions.pred_masks,
            model_inputs["original_sizes"],
            model_inputs["reshaped_input_sizes"],
        )

        # First (only) image in the batch; candidate 0 for every box.
        first_candidate_masks = per_image_masks[0][:, 0]
        first_candidate_scores = predictions.iou_scores[0, :, 0]

        return first_candidate_masks.cpu().numpy(), first_candidate_scores.cpu().numpy()


if __name__ == "__main__":
    # `sv` and `Image` are already imported at module level; no re-import needed.

    # Look up the class by app/class name (presumably requires the app to be
    # deployed first) and instantiate a remote handle.
    segment_anything = modal.Cls.from_name(app.name, SegmentAnythingModalApp.__name__)()

    # Use a distinct name so the module-level Modal `image` is not shadowed.
    # The `with` block closes the file handle; convert("RGB") yields an in-memory
    # 3-channel copy so palette/alpha PNGs reach the model and annotator as RGB.
    with Image.open("images/image.png") as source_file:
        scene_image = source_file.convert("RGB")

    bounding_boxes = [
        [2449, 2021, 2758, 2359],
        [436, 1942, 1002, 2193],
        [2785, 1945, 3259, 2374],
        [3285, 1996, 3721, 2405],
        [1968, 2035, 2451, 2474],
        [1741, 1909, 2098, 2320],
    ]
    masks, iou_scores = segment_anything.forward.remote(image=scene_image, bounding_boxes=bounding_boxes)

    detections = get_detections_from_segment_anything(bounding_boxes, masks, iou_scores)
    mask_annotator = sv.MaskAnnotator()
    annotated_image = mask_annotator.annotate(scene=scene_image, detections=detections)

    annotated_image.save("images/resized_image_with_detections_annotated.png")