File size: 1,948 Bytes
518d841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import modal
import numpy as np
import supervision as sv
from smolagents import Tool

from modal_apps.app import app
from modal_apps.segment_anything import SegmentAnythingModalApp


def get_detections_from_segment_anything(detections, list_of_masks, iou_scores):
    """Combine input boxes with segmentation results into a new ``sv.Detections``.

    Args:
        detections: Detections whose ``xyxy`` boxes were segmented. Only the boxes
            are reused; every other field of the input (class_id, confidence,
            tracker_id, data, ...) is not carried over.
        list_of_masks: One mask per box, in the same order as ``detections.xyxy``.
            NOTE(review): presumably an array-like of (H, W) masks so that the
            stacked result is (n, H, W) as sv.Detections expects — confirm
            against SegmentAnythingModalApp.
        iou_scores: One score per box, in box order; stored as ``confidence``.

    Returns:
        A new ``sv.Detections`` holding a copy of the input boxes, the masks,
        ``class_id`` set to the box index (0..n-1) and ``confidence`` set to
        ``iou_scores``.
    """
    # Copy the boxes as an (n, 4) float array. Round-tripping through
    # ``.tolist()`` turned an empty input into shape (0,), which sv.Detections
    # rejects; the reshape keeps the 2-D shape for any n.
    xyxy = np.array(detections.xyxy, dtype=float).reshape(-1, 4)
    return sv.Detections(
        xyxy=xyxy,
        mask=np.array(list_of_masks),
        # The box index doubles as class id; np.arange keeps an integer dtype
        # even for zero boxes (np.array([]) would be float64).
        class_id=np.arange(len(xyxy)),
        confidence=np.array(iou_scores),
    )


class SegmentAnythingTool(Tool):
    """smolagents tool that adds segmentation masks to existing detections.

    ``forward`` sends the image and the detection boxes to the remote
    ``SegmentAnythingModalApp`` class (looked up by name in the Modal app
    ``app.name``) and wraps the returned masks and scores into a new
    ``sv.Detections``.
    """

    name = "segment_anything"
    # This text is the tool's contract as seen by the agent, so it must match
    # what ``forward`` really returns: a *new* Detections object in which only
    # the boxes come from the input (see get_detections_from_segment_anything).
    description = """
        Given an image and already detected objects (a sv.Detections object), segment the image and return a mask for each bounding box.
        The image is a PIL image.
        The detections are an object of type sv.Detections, obtainable from the usage of the object_detection tool with task_inference_output_converter.

        The output is a new sv.Detections object with the same bounding boxes, in the same order, as the input and one mask per box.
        Only the boxes are kept from the input: class_id is set to the box index (0, 1, 2, ...),
        confidence is set to the IoU score returned by the segmentation model for each mask,
        and any other field of the input (such as tracker_id or data) is not carried over.
    """

    inputs = {
        "image": {
            "type": "image",
            "description": "The image to segment",
        },
        "detections": {
            "type": "object",
            "description": """
            The detections to segment the image with. 
            The detections are an object of type supervision.Detections.
            """,
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Handle to an instance of the Modal class registered under
        # SegmentAnythingModalApp's class name in the app named ``app.name``.
        self.modal_app = modal.Cls.from_name(app.name, SegmentAnythingModalApp.__name__)()

    def forward(
        self,
        image,
        detections: sv.Detections,
    ):
        """Segment ``image`` inside each box of ``detections``.

        Args:
            image: The image to segment (described to the agent as a PIL image).
            detections: Detections whose ``xyxy`` boxes are sent to the remote
                segmentation model.

        Returns:
            A new ``sv.Detections`` built by ``get_detections_from_segment_anything``
            from the input boxes and the remote model's masks and IoU scores.
        """
        bounding_boxes = detections.xyxy.tolist()
        # Remote call to the Modal class; the boxes are sent as plain lists.
        # NOTE(review): assumes the remote returns (masks, iou_scores) with one
        # entry per box, in box order — confirm against SegmentAnythingModalApp.
        masks, iou_scores = self.modal_app.forward.remote(image=image, bounding_boxes=bounding_boxes)
        return get_detections_from_segment_anything(detections, masks, iou_scores)