ScouterAI / tools /inference_converter.py
stevenbucaille's picture
Enhance image processing capabilities and update project structure
111afa2
raw
history blame
4.37 kB
from typing import Dict, List, Union
import numpy as np
import PIL
import supervision as sv
from smolagents import Tool
def get_class_ids_from_labels(labels: List[str]) -> List[int]:
    """Map each label to an integer class id, one id per distinct label.

    Ids are assigned in order of first appearance in ``labels`` (the first
    distinct label gets 0, the next new one gets 1, ...), so the same input
    always yields the same ids.

    Args:
        labels: Label strings, one per detection. May be empty.

    Returns:
        A list of the same length as ``labels`` holding the id of each label.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order. A plain set()
    # would iterate strings in an order that can change between interpreter
    # runs (hash randomization), making the ids non-reproducible.
    label_to_id = {label: idx for idx, label in enumerate(dict.fromkeys(labels))}
    return [label_to_id[label] for label in labels]
def create_detections_from_image_segmentation_output(
    # The PIL type is quoted so it is not evaluated at definition time:
    # `import PIL` alone does not guarantee the `PIL.Image` submodule is loaded.
    image_segmentation_output: List[Dict[str, Union[str, "PIL.Image.Image"]]],
):
    """Build an ``sv.Detections`` from image-segmentation results.

    Each entry must provide a ``"mask"`` and a ``"label"``. A mask is turned
    into a boolean array (pixel value > 0) and its bounding box is taken from
    the extreme coordinates of the set pixels. Entries whose mask has no set
    pixel are dropped entirely, so boxes, masks, class ids and labels always
    have the same length.

    Args:
        image_segmentation_output: List of dicts with ``"mask"`` and
            ``"label"`` keys. Masks are expected to be 2-D (the set-pixel
            coordinates are unpacked into exactly two index arrays).
            NOTE(review): masks presumably all share one size so they can be
            stacked into a single array - confirm against the caller.

    Returns:
        ``sv.Detections`` with ``xyxy``, ``mask``, ``class_id`` and
        ``metadata={"labels": [...]}`` for the entries that were kept.
    """
    boxes = []
    bool_masks = []
    labels = []
    for detection in image_segmentation_output:
        # Convert once and reuse for both the box and the stored mask.
        bool_mask = np.array(detection["mask"]) > 0
        y_indices, x_indices = np.where(bool_mask)
        if y_indices.size == 0:
            # An empty mask has no bounding box. Skipping only the box (and
            # keeping the mask/label) would leave the arrays misaligned, so
            # the whole entry is skipped.
            continue
        boxes.append((x_indices.min(), y_indices.min(), x_indices.max(), y_indices.max()))
        bool_masks.append(bool_mask)
        labels.append(detection["label"])

    class_ids = get_class_ids_from_labels(labels)
    detections = sv.Detections(
        # reshape keeps the (n, 4) layout even when no detection was kept.
        xyxy=np.array(boxes).reshape(-1, 4),
        # With nothing to stack there is no (n, H, W) array to build.
        mask=np.array(bool_masks) if bool_masks else None,
        class_id=np.array(class_ids, dtype=int),
        metadata={"labels": labels},
    )
    return detections
def create_detections_from_object_detection_output(
    object_detection_output: List[Dict[str, Union[str, Dict[str, float], List]]],
):
    """Build an ``sv.Detections`` from object-detection results.

    Args:
        object_detection_output: List of dicts, each with a ``"label"``, a
            ``"score"`` and a ``"box"`` dict holding ``"xmin"``, ``"ymin"``,
            ``"xmax"`` and ``"ymax"``. May be empty.

    Returns:
        ``sv.Detections`` with ``xyxy``, ``confidence``, ``class_id`` and
        ``metadata={"labels": [...]}``, one row per input entry in input order.
    """
    boxes = []
    labels = []
    scores = []
    for detection in object_detection_output:
        box = detection["box"]
        boxes.append([box["xmin"], box["ymin"], box["xmax"], box["ymax"]])
        labels.append(detection["label"])
        scores.append(detection["score"])

    # Reuse the shared helper rather than re-deriving the label -> id mapping.
    class_ids = get_class_ids_from_labels(labels)

    detections = sv.Detections(
        # reshape keeps the (n, 4) layout even for an empty input list.
        xyxy=np.array(boxes).reshape(-1, 4),
        confidence=np.array(scores),
        class_id=np.array(class_ids, dtype=int),
        metadata={"labels": labels},
    )
    return detections
class TaskInferenceOutputConverterTool(Tool):
    """Tool that converts a task inference output into ``supervision.Detections``.

    Dispatches on the ``task`` name to the matching module-level converter.
    """

    name = "task_inference_output_converter"
    description = """
    Given a task inference output, convert it to a list of detections that can be used to annotate the image.
    The supported tasks are:
    - object-detection
    - image-segmentation
    In case of object-detection, the task inference output is a list of dictionaries with the following keys:
    - label: a string.
    - score: a number between 0 and 1.
    - box: a dictionary with the following keys:
        - xmin: a number
        - ymin: a number
        - xmax: a number
        - ymax: a number
    In case of image-segmentation, the task inference output is a list of dictionaries with the following keys:
    - label: a string.
    - mask: a PIL image of shape (height, width) with values in {0, 1}.
    - score: an optional number between 0 and 1, can be None.
    The output is a list of detections that can be used to annotate the image.
    The detections is an object of type supervision.Detections.
    """
    inputs = {
        "task_inference_output": {
            "type": "array",
            "description": "The task inference output to convert to detections",
        },
        "task": {
            # The task is a single name compared against string literals in
            # forward(), so it is declared as a string rather than an array.
            "type": "string",
            "description": """
            The task to convert the task inference output to detections for.
            The supported tasks are:
            - object-detection
            - image-segmentation
            """,
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()

    def forward(
        self,
        task_inference_output: List[Dict[str, Union[str, float, Dict[str, float]]]],
        task: str,
    ):
        """Convert ``task_inference_output`` to detections for the given ``task``.

        Args:
            task_inference_output: List of per-detection dicts in the layout
                described in the tool description for the chosen task.
            task: Either ``"object-detection"`` or ``"image-segmentation"``.

        Returns:
            The object returned by the matching converter function.

        Raises:
            ValueError: If ``task`` is not one of the supported task names.
        """
        if task == "object-detection":
            result = create_detections_from_object_detection_output(task_inference_output)
        elif task == "image-segmentation":
            result = create_detections_from_image_segmentation_output(task_inference_output)
        else:
            raise ValueError(f"Task {task} is not supported")
        return result