File size: 1,771 Bytes
111afa2
 
 
 
 
 
 
 
 
 
518d841
111afa2
518d841
111afa2
 
 
 
518d841
 
111afa2
 
 
 
 
 
 
 
 
 
 
518d841
 
 
 
111afa2
 
 
 
 
 
 
 
 
 
 
518d841
111afa2
518d841
 
 
111afa2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging

import modal
from smolagents import Tool

from modal_apps.app import app
from modal_apps.inference_pipeline import InferencePipelineModalApp


class ImageSegmentationTool(Tool):
    """smolagents tool that runs image segmentation on a deployed Modal inference pipeline.

    No model runs locally: ``forward`` passes its arguments to the remote
    ``InferencePipelineModalApp.forward`` with ``task="image-segmentation"`` and
    returns the remote result unchanged.
    """

    # Tool metadata consumed by smolagents to describe the tool to the agent.
    # The keys of `inputs` must match the parameter names of `forward`.
    name = "image_segmentation"
    description = """
        Given an image, segment the image and return the masks.
        The image is a PIL image.
        The output is a list of dictionaries containing the masks with the following keys:
        - score: an optional number between 0 and 1, can be None.
        - label: a string
        - mask: a PIL image
        You need to provide the model name to use for image segmentation.
        The tool returns a list of masks for all the objects in the image.
        You also need to provide a score threshold to filter the masks.
    """

    inputs = {
        "image": {
            "type": "image",
            "description": "The image to segment",
        },
        "model_name": {
            "type": "string",
            "description": "The name of the model to use for image segmentation",
        },
        "threshold": {
            "type": "number",
            "description": "The score threshold of the masks to return",
        },
    }
    output_type = "object"

    def __init__(self):
        """Create the tool and bind it to the deployed inference-pipeline Modal class."""
        super().__init__()
        # Look up the class deployed under the shared Modal app by its Python class
        # name and instantiate a handle; `.remote(...)` calls on it execute on Modal.
        self.modal_app = modal.Cls.from_name(app.name, InferencePipelineModalApp.__name__)()

    def forward(
        self,
        image,
        model_name: str,
        threshold: float,
    ):
        """Segment ``image`` on the remote pipeline and return the resulting masks.

        Args:
            image: The image to segment (a PIL image, per the tool description).
            model_name: Name of the segmentation model the remote pipeline should load.
            threshold: Score threshold forwarded to the remote pipeline to filter masks.

        Returns:
            The remote pipeline's result, unchanged. Per the tool description this is
            a list of dicts with ``score``, ``label`` and ``mask`` keys.
        """
        segments = self.modal_app.forward.remote(
            model_name=model_name, task="image-segmentation", image=image, threshold=threshold
        )
        logger = logging.getLogger(__name__)
        # Diagnostics go through logging rather than print() so that callers control
        # verbosity; the full result (which includes mask images) is DEBUG-only.
        logger.debug("Segments: %s", segments)
        for segment in segments:
            # .get(): a diagnostic line must not fail a call whose segmentation succeeded.
            logger.info("Found segment of %s", segment.get("label"))
        return segments