from typing import Dict, List, Union

import numpy as np
import PIL.Image
import supervision as sv
from smolagents import Tool
def get_class_ids_from_labels(labels: List[str]):
    # Map each label string to an integer class id; repeated labels share an id.
    unique_labels = list(set(labels))
    label_to_id = {label: idx for idx, label in enumerate(unique_labels)}
    class_ids = [label_to_id[label] for label in labels]
    return class_ids
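# Illustrative example (not executed): the specific id assigned to each label
# depends on set iteration order, but repeated labels always get the same id:
#   get_class_ids_from_labels(["cat", "dog", "cat"])  ->  [0, 1, 0] or [1, 0, 1]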
def create_detections_from_image_segmentation_output(
    image_segmentation_output: List[Dict[str, Union[str, PIL.Image.Image]]],
):
    # Convert binary PIL masks into boolean arrays and derive tight bounding boxes.
    xyxy = []
    masks = []
    labels = []
    for detection in image_segmentation_output:
        mask_array = np.array(detection["mask"]) > 0
        y_indices, x_indices = np.where(mask_array)
        # Skip empty masks so boxes, masks and labels stay aligned.
        if len(y_indices) == 0 or len(x_indices) == 0:
            continue
        xmin, xmax = np.min(x_indices), np.max(x_indices)
        ymin, ymax = np.min(y_indices), np.max(y_indices)
        xyxy.append((xmin, ymin, xmax, ymax))
        masks.append(mask_array)
        labels.append(detection["label"])
    class_ids = get_class_ids_from_labels(labels)
    detections = sv.Detections(
        xyxy=np.array(xyxy),
        mask=np.array(masks),
        class_id=np.array(class_ids),
        metadata={"labels": labels},
    )
    return detections
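# Example sketch (commented out; assumes a transformers image-segmentation
# pipeline and a PIL image named `image` -- both names are illustrative):
#
#   from transformers import pipeline
#   segmenter = pipeline("image-segmentation")
#   detections = create_detections_from_image_segmentation_output(segmenter(image))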
def create_detections_from_object_detection_output(
    object_detection_output: List[Dict[str, Union[str, Dict[str, float], List]]],
):
    # Collect boxes in xyxy order and reuse the shared label-to-id mapping helper.
    bboxes = [
        [detection["box"]["xmin"], detection["box"]["ymin"], detection["box"]["xmax"], detection["box"]["ymax"]]
        for detection in object_detection_output
    ]
    labels = [detection["label"] for detection in object_detection_output]
    class_ids = get_class_ids_from_labels(labels)
    detections = sv.Detections(
        xyxy=np.array(bboxes),
        confidence=np.array([detection["score"] for detection in object_detection_output]),
        class_id=np.array(class_ids),
        metadata={"labels": labels},
    )
    return detections
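# Example sketch (commented out; assumes a transformers object-detection
# pipeline and a PIL image named `image` -- both names are illustrative):
#
#   from transformers import pipeline
#   detector = pipeline("object-detection")
#   detections = create_detections_from_object_detection_output(detector(image))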
class TaskInferenceOutputConverterTool(Tool):
    name = "task_inference_output_converter"
    description = """
    Given a task inference output, convert it to detections that can be used to annotate the image.
    The supported tasks are:
    - object-detection
    - image-segmentation
    In case of object-detection, the task inference output is a list of dictionaries with the following keys:
    - label: a string.
    - score: a number between 0 and 1.
    - box: a dictionary with the following keys:
        - xmin: a number
        - ymin: a number
        - xmax: a number
        - ymax: a number
    In case of image-segmentation, the task inference output is a list of dictionaries with the following keys:
    - label: a string.
    - mask: a PIL image of shape (height, width) with values in {0, 1}.
    - score: an optional number between 0 and 1, can be None.
    The output is a supervision.Detections object that can be used to annotate the image.
    """
    inputs = {
        "task_inference_output": {
            "type": "array",
            "description": "The task inference output to convert to detections",
        },
        "task": {
            "type": "string",
            "description": """
            The task to convert the task inference output to detections for.
            The supported tasks are:
            - object-detection
            - image-segmentation
            """,
        },
    }
    output_type = "object"
    def __init__(self):
        super().__init__()

    def forward(
        self,
        task_inference_output: List[Dict[str, Union[str, float, Dict[str, float], PIL.Image.Image]]],
        task: str,
    ):
        if task == "object-detection":
            result = create_detections_from_object_detection_output(task_inference_output)
        elif task == "image-segmentation":
            result = create_detections_from_image_segmentation_output(task_inference_output)
        else:
            raise ValueError(f"Task {task} is not supported")
        return result
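# Example sketch (commented out): calling the tool directly and annotating the
# image with supervision. `raw_output` is assumed to be the output of an
# object-detection pipeline run on `image`; both names are illustrative.
#
#   tool = TaskInferenceOutputConverterTool()
#   detections = tool(raw_output, task="object-detection")
#   annotated = sv.BoxAnnotator().annotate(scene=np.array(image), detections=detections)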