from typing import Dict, List, Union

import numpy as np
import PIL.Image
import supervision as sv
from smolagents import Tool


def get_class_ids_from_labels(labels: List[str]) -> List[int]:
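    """Map each label string to an integer class id, assigned in first-seen order.

    Illustrative example: ["cat", "dog", "cat"] -> [0, 1, 0].
    """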
    unique_labels = list(dict.fromkeys(labels))  # first-seen order keeps ids deterministic across runs
    label_to_id = {label: idx for idx, label in enumerate(unique_labels)}
    class_ids = [label_to_id[label] for label in labels]
    return class_ids


def create_detections_from_image_segmentation_output(
    image_segmentation_output: List[Dict[str, Union[str, PIL.Image.Image]]],
):
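    """Convert image-segmentation output (label + PIL mask dicts) to sv.Detections."""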
    # Skip detections whose mask has no foreground pixels so that the xyxy,
    # mask, and label arrays stay aligned (an empty mask has no bounding box).
    xyxy = []
    masks = []
    labels = []
    for detection in image_segmentation_output:
        mask_array = np.array(detection["mask"]) > 0
        y_indices, x_indices = np.where(mask_array)
        if len(y_indices) == 0:
            continue
        xmin, xmax = np.min(x_indices), np.max(x_indices)
        ymin, ymax = np.min(y_indices), np.max(y_indices)
        xyxy.append((xmin, ymin, xmax, ymax))
        masks.append(mask_array)
        labels.append(detection["label"])

    class_ids = get_class_ids_from_labels(labels)

    detections = sv.Detections(
        xyxy=np.array(xyxy),
        mask=np.array(masks),
        class_id=np.array(class_ids),
        metadata={"labels": labels},
    )
    return detections


def create_detections_from_object_detection_output(
    object_detection_output: List[Dict[str, Union[str, Dict[str, float], List]]],
):
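    """Convert object-detection output (label/score/box dicts) to sv.Detections."""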
    bboxes = [
        [detection["box"]["xmin"], detection["box"]["ymin"], detection["box"]["xmax"], detection["box"]["ymax"]]
        for detection in object_detection_output
    ]
    labels = [detection["label"] for detection in object_detection_output]
    class_ids = get_class_ids_from_labels(labels)

    detections = sv.Detections(
        xyxy=np.array(bboxes),
        confidence=np.array([detection["score"] for detection in object_detection_output]),
        class_id=np.array(class_ids),
        metadata={"labels": labels},
    )
    return detections


def create_detections_from_segment_anything(
    segment_anything_output: List[Dict[str, Union[List[int], np.ndarray, float]]],
):
    """Convert segment-anything output (box/mask/score dicts) to sv.Detections."""
    bounding_boxes = [segmentation["box"] for segmentation in segment_anything_output]
    masks = [segmentation["mask"] for segmentation in segment_anything_output]
    iou_scores = [segmentation["score"] for segmentation in segment_anything_output]

    detections = sv.Detections(
        xyxy=np.array(bounding_boxes),
        mask=np.array(masks),
        class_id=np.array(list(range(len(bounding_boxes)))),
        confidence=np.array(iou_scores),
    )
    return detections


class TaskInferenceOutputConverterTool(Tool):
    name = "task_inference_output_converter"
    description = """
        Given a task inference output, convert it to a list of detections that can be used to annotate the image.
        The supported tasks are:
        - object-detection
        - image-segmentation
        - segment-anything

        In case of object-detection, the task inference output is a list of dictionaries with the following keys:
        - label: a string.
        - score: a number between 0 and 1.
        - box: a dictionary with the following keys:
            - xmin: a number
            - ymin: a number
            - xmax: a number
            - ymax: a number
        
        In case of image-segmentation, the task inference output is a list of dictionaries with the following keys:
        - label: a string.
        - mask: a PIL image of shape (height, width) with values in {0, 1}.
        - score: an optional number between 0 and 1, can be None.

        In case of segment-anything, the task inference output is a list of dictionaries with the following keys:
        - bounding_boxes: a list of lists of bounding boxes.
        - masks: a list of lists of masks.
        - iou_scores: a list of lists of iou scores.

        The output is a list of detections that can be used to annotate the image.
        The detections is an object of type supervision.Detections.
    """

    inputs = {
        "task_inference_output": {
            "type": "array",
            "description": "The task inference output to convert to detections",
        },
        "task": {
            "type": "array",
            "description": """
            The task to convert the task inference output to detections for.
            The supported tasks are:
            - object-detection
            - image-segmentation
            - segment-anything
            """,
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()

    def forward(
        self,
        task_inference_output: List[Dict[str, Union[str, float, Dict[str, float]]]],
        task: str,
    ):
        if task == "object-detection":
            result = create_detections_from_object_detection_output(task_inference_output)
        elif task == "image-segmentation":
            result = create_detections_from_image_segmentation_output(task_inference_output)
        elif task == "segment-anything":
            result = create_detections_from_segment_anything(task_inference_output)
        else:
            raise ValueError(f"Task {task} is not supported")
        return result
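

# Illustrative usage sketch (not part of the original tool API): converts a
# made-up object-detection output into sv.Detections. The labels, scores, and
# boxes below are invented sample values.
if __name__ == "__main__":
    sample_output = [
        {"label": "cat", "score": 0.97, "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 220}},
        {"label": "dog", "score": 0.88, "box": {"xmin": 150, "ymin": 30, "xmax": 300, "ymax": 240}},
    ]
    tool = TaskInferenceOutputConverterTool()
    detections = tool.forward(sample_output, task="object-detection")
    print(detections)  # sv.Detections with two boxes, confidences, and class ids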