Spaces:
Running
Running
File size: 4,368 Bytes
111afa2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from typing import Dict, List, Union
import numpy as np
import PIL
import supervision as sv
from smolagents import Tool
def get_class_ids_from_labels(labels: List[str]) -> List[int]:
    """Map each label to an integer class id that is stable across runs.

    Ids are assigned in order of first appearance: the first distinct label
    gets 0, the next new label gets 1, and so on. Iterating a ``set`` of
    strings (the previous approach) depends on per-process hash
    randomization, so the same labels could receive different ids, and thus
    different annotation colors, from one run to the next.

    Args:
        labels: One label per detection; repeated labels share an id.

    Returns:
        A list the same length as ``labels`` holding each label's class id.
    """
    # dict preserves insertion order, so this deduplicates in first-seen order.
    label_to_id = {label: idx for idx, label in enumerate(dict.fromkeys(labels))}
    return [label_to_id[label] for label in labels]
def create_detections_from_image_segmentation_output(
    image_segmentation_output: List[Dict[str, Union[str, PIL.Image.Image]]],
):
    """Convert image-segmentation output into an ``sv.Detections`` object.

    Each entry must provide a ``"mask"`` (anything ``np.array`` accepts, e.g.
    a PIL image) and a ``"label"``. An optional ``"score"`` is used as the
    detection confidence, but only when every entry supplies a non-None score.

    A bounding box is derived from the non-zero pixels of each mask. A mask
    with no non-zero pixels gets a ``(0, 0, 0, 0)`` placeholder box, so boxes,
    masks and labels stay aligned one-to-one.

    Args:
        image_segmentation_output: One dictionary per segment.

    Returns:
        ``sv.Detections`` with ``xyxy``, ``mask`` (None when there are no
        segments), ``class_id``, ``metadata["labels"]``, and ``confidence``
        when scores are available.
    """
    labels = [detection["label"] for detection in image_segmentation_output]
    # Convert each mask to a boolean array exactly once.
    masks = [np.array(detection["mask"]) > 0 for detection in image_segmentation_output]

    xyxy = []
    for mask in masks:
        y_indices, x_indices = np.nonzero(mask)
        if y_indices.size == 0:
            # Empty mask: keep a placeholder so xyxy stays aligned with masks/labels.
            xyxy.append((0, 0, 0, 0))
        else:
            xyxy.append(
                (x_indices.min(), y_indices.min(), x_indices.max(), y_indices.max())
            )

    scores = [detection.get("score") for detection in image_segmentation_output]
    confidence = (
        np.array(scores, dtype=float)
        if scores and all(score is not None for score in scores)
        else None
    )

    return sv.Detections(
        # reshape keeps the required (N, 4) shape even when N == 0.
        xyxy=np.array(xyxy).reshape(-1, 4),
        mask=np.array(masks) if masks else None,
        confidence=confidence,
        class_id=np.array(get_class_ids_from_labels(labels), dtype=int),
        metadata={"labels": labels},
    )
def create_detections_from_object_detection_output(
    object_detection_output: List[Dict[str, Union[str, Dict[str, float], List]]],
):
    """Convert object-detection output into an ``sv.Detections`` object.

    Each entry must provide a ``"label"``, a ``"score"``, and a ``"box"``
    dictionary with ``"xmin"``, ``"ymin"``, ``"xmax"`` and ``"ymax"`` keys.

    Args:
        object_detection_output: One dictionary per detected object.

    Returns:
        ``sv.Detections`` with ``xyxy``, ``confidence``, ``class_id`` and
        ``metadata["labels"]`` populated; an empty input yields empty arrays
        of the correct shape.
    """
    bboxes = [
        [detection["box"][key] for key in ("xmin", "ymin", "xmax", "ymax")]
        for detection in object_detection_output
    ]
    labels = [detection["label"] for detection in object_detection_output]
    scores = [detection["score"] for detection in object_detection_output]

    return sv.Detections(
        # reshape keeps the required (N, 4) shape even when N == 0.
        xyxy=np.array(bboxes).reshape(-1, 4),
        confidence=np.array(scores, dtype=float),
        # Shared helper keeps id assignment deterministic and consistent with
        # the segmentation converter.
        class_id=np.array(get_class_ids_from_labels(labels), dtype=int),
        metadata={"labels": labels},
    )
class TaskInferenceOutputConverterTool(Tool):
    """smolagents tool that turns raw pipeline output into ``sv.Detections``."""

    name = "task_inference_output_converter"
    description = """
Given a task inference output, convert it to a list of detections that can be used to annotate the image.
The supported tasks are:
- object-detection
- image-segmentation
In case of object-detection, the task inference output is a list of dictionaries with the following keys:
- label: a string.
- score: a number between 0 and 1.
- box: a dictionary with the following keys:
- xmin: a number
- ymin: a number
- xmax: a number
- ymax: a number
In case of image-segmentation, the task inference output is a list of dictionaries with the following keys:
- label: a string.
- mask: a PIL image of shape (height, width) with values in {0, 1}.
- score: an optional number between 0 and 1, can be None.
The output is a list of detections that can be used to annotate the image.
The detections is an object of type supervision.Detections.
"""
    inputs = {
        "task_inference_output": {
            "type": "array",
            "description": "The task inference output to convert to detections",
        },
        # `task` is a single task name compared against string literals in
        # `forward`, so its declared JSON type is "string" (it was "array").
        "task": {
            "type": "string",
            "description": """
The task to convert the task inference output to detections for.
The supported tasks are:
- object-detection
- image-segmentation
""",
        },
    }
    output_type = "object"

    def forward(
        self,
        task_inference_output: List[Dict[str, Union[str, float, Dict[str, float]]]],
        task: str,
    ):
        """Dispatch ``task_inference_output`` to the converter for ``task``.

        Args:
            task_inference_output: Raw pipeline output for ``task``.
            task: Either ``"object-detection"`` or ``"image-segmentation"``.

        Returns:
            The ``sv.Detections`` built by the matching converter.

        Raises:
            ValueError: If ``task`` is not a supported task name.
        """
        if task == "object-detection":
            return create_detections_from_object_detection_output(task_inference_output)
        if task == "image-segmentation":
            return create_detections_from_image_segmentation_output(task_inference_output)
        raise ValueError(f"Task {task} is not supported")
|