|
from typing import List, Optional, Tuple |
|
|
|
import numpy as np |
|
|
|
from inference.core.entities.responses.inference import ( |
|
InferenceResponseImage, |
|
Keypoint, |
|
KeypointsDetectionInferenceResponse, |
|
KeypointsPrediction, |
|
) |
|
from inference.core.exceptions import ModelArtefactError |
|
from inference.core.models.object_detection_base import ( |
|
ObjectDetectionBaseOnnxRoboflowInferenceModel, |
|
) |
|
from inference.core.models.types import PreprocessReturnMetadata |
|
from inference.core.models.utils.keypoints import model_keypoints_to_response |
|
from inference.core.models.utils.validate import ( |
|
get_num_classes_from_model_prediction_shape, |
|
) |
|
from inference.core.nms import w_np_non_max_suppression |
|
from inference.core.utils.postprocess import post_process_bboxes, post_process_keypoints |
|
|
|
# Default post-processing parameters, used as keyword defaults by ``postprocess``.
DEFAULT_CONFIDENCE = 0.4
DEFAULT_IOU_THRESH = 0.3
DEFAULT_CLASS_AGNOSTIC_NMS = False
DEFAULT_MAX_DETECTIONS = 300
# Misspelled legacy name (lowercase "l"); kept as an alias so that existing
# references to it keep working.
DEFAUlT_MAX_DETECTIONS = DEFAULT_MAX_DETECTIONS
DEFAULT_MAX_CANDIDATES = 3000
|
|
|
|
|
class KeypointsDetectionBaseOnnxRoboflowInferenceModel(
    ObjectDetectionBaseOnnxRoboflowInferenceModel
):
    """Roboflow ONNX keypoints detection model.

    Extends the object detection ONNX base model with keypoint-specific
    post-processing and response construction. Subclasses must implement
    ``keypoints_count``.
    """

    task_type = "keypoint-detection"

    def __init__(self, model_id: str, *args, **kwargs):
        """Initializes the model by delegating to the parent class.

        Args:
            model_id (str): Identifier of the model, forwarded to the parent.
            *args: Positional arguments forwarded to the parent initializer.
            **kwargs: Keyword arguments forwarded to the parent initializer.
        """
        super().__init__(model_id, *args, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Returns the list of files to be downloaded from the inference bucket for ONNX model.

        Returns:
            list: A list of filenames specific to ONNX keypoints models.
        """
        return ["environment.json", "class_names.txt", "keypoints_metadata.json"]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preproc_return_metadata: PreprocessReturnMetadata,
        class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
        confidence: float = DEFAULT_CONFIDENCE,
        iou_threshold: float = DEFAULT_IOU_THRESH,
        max_candidates: int = DEFAULT_MAX_CANDIDATES,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        return_image_dims: bool = False,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Postprocesses the raw keypoints detection predictions.

        Runs non-max suppression, rescales boxes and keypoints using the
        pre-processing metadata, and builds the response objects.

        Args:
            predictions (Tuple[np.ndarray]): Raw model outputs; only the first
                element is used.
            preproc_return_metadata (PreprocessReturnMetadata): Metadata produced
                by pre-processing. The keys ``"img_dims"`` and
                ``"disable_preproc_static_crop"`` are read from it.
            class_agnostic_nms (bool): Whether to apply class-agnostic non-max
                suppression. Default is False.
            confidence (float): Confidence threshold for filtering detections.
                Default is 0.4.
            iou_threshold (float): IoU threshold for non-max suppression.
                Default is 0.3.
            max_candidates (int): Maximum number of candidate detections.
                Default is 3000.
            max_detections (int): Maximum number of final detections.
                Default is 300.
            return_image_dims (bool): Accepted for interface compatibility; not
                used by this method.
            **kwargs: Forwarded unchanged to ``make_response``.

        Returns:
            List[KeypointsDetectionInferenceResponse]: The post-processed predictions.
        """
        raw_predictions = predictions[0]
        number_of_classes = len(self.get_class_names)
        # Columns left after the 5 leading values and the per-class scores are
        # treated as keypoint data (presumably 4 box coordinates + 1 confidence
        # lead each row -- TODO confirm against the exported model).
        num_masks = raw_predictions.shape[2] - 5 - number_of_classes
        nms_predictions = w_np_non_max_suppression(
            raw_predictions,
            conf_thresh=confidence,
            iou_thresh=iou_threshold,
            class_agnostic=class_agnostic_nms,
            max_detections=max_detections,
            max_candidate_detections=max_candidates,
            num_masks=num_masks,
        )

        infer_shape = (self.img_size_h, self.img_size_w)
        img_dims = preproc_return_metadata["img_dims"]
        disable_preproc_static_crop = preproc_return_metadata[
            "disable_preproc_static_crop"
        ]
        # Boxes and keypoints are rescaled with the same geometry parameters.
        rescale_kwargs = dict(
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=disable_preproc_static_crop,
        )
        rescaled = post_process_bboxes(predictions=nms_predictions, **rescale_kwargs)
        rescaled = post_process_keypoints(
            predictions=rescaled,
            # Keypoint values occupy the trailing ``num_masks`` columns.
            keypoints_start_index=-num_masks,
            **rescale_kwargs,
        )
        return self.make_response(rescaled, img_dims, **kwargs)

    def make_response(
        self,
        predictions: List[List[float]],
        img_dims: List[Tuple[int, int]],
        class_filter: Optional[List[str]] = None,
        *args,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Constructs keypoints detection response objects based on predictions.

        Each prediction row is read as ``[x1, y1, x2, y2, confidence, <unused>,
        class_id, keypoint values...]``: indices 0-3 are box corners, index 4
        the confidence, index 6 the class id and everything from index 7 on is
        handed to ``model_keypoints_to_response``.

        Args:
            predictions (List[List[float]]): Per-image lists of prediction rows.
            img_dims (List[Tuple[int, int]]): Dimensions of the images, indexed
                as ``(height, width)``. A dict holding them under the
                ``"img_dims"`` key is also accepted.
            class_filter (Optional[List[str]]): A list of class names to keep,
                if provided. When empty or None, every prediction is kept.
            *args: Ignored.
            **kwargs: If a ``"request"`` entry is present, its
                ``keypoint_confidence`` attribute is used as the keypoint
                confidence threshold; otherwise the threshold is 0.0.

        Returns:
            List[KeypointsDetectionInferenceResponse]: A list of response objects
            containing keypoints detection predictions, one per image.
        """
        if isinstance(img_dims, dict) and "img_dims" in img_dims:
            img_dims = img_dims["img_dims"]
        keypoint_confidence_threshold = 0.0
        if "request" in kwargs:
            keypoint_confidence_threshold = kwargs["request"].keypoint_confidence

        responses = []
        for ind, batch_predictions in enumerate(predictions):
            image_predictions = []
            for pred in batch_predictions:
                class_id = int(pred[6])
                class_name = self.class_names[class_id]
                if class_filter and class_name not in class_filter:
                    continue
                keypoints = model_keypoints_to_response(
                    keypoints_metadata=self.keypoints_metadata,
                    keypoints=pred[7:],
                    # Use the same class-id column as the "class_id" field.
                    # The previous index (4 + number of classes) only landed
                    # on column 6 for two-class models and pointed at other
                    # columns otherwise.
                    predicted_object_class_id=class_id,
                    keypoint_confidence_threshold=keypoint_confidence_threshold,
                )
                image_predictions.append(
                    KeypointsPrediction(
                        # Passed as a dict because "class" is a Python keyword
                        # and cannot be written as a keyword argument.
                        **{
                            "x": (pred[0] + pred[2]) / 2,
                            "y": (pred[1] + pred[3]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "confidence": pred[4],
                            "class": class_name,
                            "class_id": class_id,
                            "keypoints": keypoints,
                        }
                    )
                )
            responses.append(
                KeypointsDetectionInferenceResponse(
                    predictions=image_predictions,
                    image=InferenceResponseImage(
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                )
            )
        return responses

    def keypoints_count(self) -> int:
        """Returns the number of keypoints predicted by the model.

        Must be overridden by subclasses; it is used by
        ``validate_model_classes``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def validate_model_classes(self) -> None:
        """Checks that the model output shape agrees with the configured classes.

        The number of classes is derived from the third dimension of the model
        output shape and the keypoints count, then compared with
        ``self.num_classes``.

        Raises:
            ValueError: If the derived number of classes differs from
                ``self.num_classes``.
        """
        num_keypoints = self.keypoints_count()
        output_shape = self.get_model_output_shape()
        num_classes = get_num_classes_from_model_prediction_shape(
            len_prediction=output_shape[2], keypoints=num_keypoints
        )
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )
|
|