|
import json |
|
from typing import List, Tuple |
|
|
|
from inference.core.active_learning.entities import ( |
|
Prediction, |
|
PredictionFileType, |
|
PredictionType, |
|
SerialisedPrediction, |
|
) |
|
from inference.core.constants import ( |
|
CLASSIFICATION_TASK, |
|
INSTANCE_SEGMENTATION_TASK, |
|
OBJECT_DETECTION_TASK, |
|
) |
|
from inference.core.exceptions import PredictionFormatNotSupported |
|
|
|
|
|
def adjust_prediction_to_client_scaling_factor( |
|
prediction: dict, scaling_factor: float, prediction_type: PredictionType |
|
) -> dict: |
|
if abs(scaling_factor - 1.0) < 1e-5: |
|
return prediction |
|
if "image" in prediction: |
|
prediction["image"] = { |
|
"width": round(prediction["image"]["width"] / scaling_factor), |
|
"height": round(prediction["image"]["height"] / scaling_factor), |
|
} |
|
if predictions_should_not_be_post_processed( |
|
prediction=prediction, prediction_type=prediction_type |
|
): |
|
return prediction |
|
if prediction_type == INSTANCE_SEGMENTATION_TASK: |
|
prediction["predictions"] = ( |
|
adjust_prediction_with_bbox_and_points_to_client_scaling_factor( |
|
predictions=prediction["predictions"], |
|
scaling_factor=scaling_factor, |
|
points_key="points", |
|
) |
|
) |
|
if prediction_type == OBJECT_DETECTION_TASK: |
|
prediction["predictions"] = ( |
|
adjust_object_detection_predictions_to_client_scaling_factor( |
|
predictions=prediction["predictions"], |
|
scaling_factor=scaling_factor, |
|
) |
|
) |
|
return prediction |
|
|
|
|
|
def predictions_should_not_be_post_processed( |
|
prediction: dict, prediction_type: PredictionType |
|
) -> bool: |
|
|
|
return ( |
|
"is_stub" in prediction |
|
or "predictions" not in prediction |
|
or CLASSIFICATION_TASK in prediction_type |
|
or len(prediction["predictions"]) == 0 |
|
) |
|
|
|
|
|
def adjust_object_detection_predictions_to_client_scaling_factor( |
|
predictions: List[dict], |
|
scaling_factor: float, |
|
) -> List[dict]: |
|
result = [] |
|
for prediction in predictions: |
|
prediction = adjust_bbox_coordinates_to_client_scaling_factor( |
|
bbox=prediction, |
|
scaling_factor=scaling_factor, |
|
) |
|
result.append(prediction) |
|
return result |
|
|
|
|
|
def adjust_prediction_with_bbox_and_points_to_client_scaling_factor( |
|
predictions: List[dict], |
|
scaling_factor: float, |
|
points_key: str, |
|
) -> List[dict]: |
|
result = [] |
|
for prediction in predictions: |
|
prediction = adjust_bbox_coordinates_to_client_scaling_factor( |
|
bbox=prediction, |
|
scaling_factor=scaling_factor, |
|
) |
|
prediction[points_key] = adjust_points_coordinates_to_client_scaling_factor( |
|
points=prediction[points_key], |
|
scaling_factor=scaling_factor, |
|
) |
|
result.append(prediction) |
|
return result |
|
|
|
|
|
def adjust_bbox_coordinates_to_client_scaling_factor( |
|
bbox: dict, |
|
scaling_factor: float, |
|
) -> dict: |
|
bbox["x"] = bbox["x"] / scaling_factor |
|
bbox["y"] = bbox["y"] / scaling_factor |
|
bbox["width"] = bbox["width"] / scaling_factor |
|
bbox["height"] = bbox["height"] / scaling_factor |
|
return bbox |
|
|
|
|
|
def adjust_points_coordinates_to_client_scaling_factor( |
|
points: List[dict], |
|
scaling_factor: float, |
|
) -> List[dict]: |
|
result = [] |
|
for point in points: |
|
point["x"] = point["x"] / scaling_factor |
|
point["y"] = point["y"] / scaling_factor |
|
result.append(point) |
|
return result |
|
|
|
|
|
def encode_prediction( |
|
prediction: Prediction, |
|
prediction_type: PredictionType, |
|
) -> Tuple[SerialisedPrediction, PredictionFileType]: |
|
if CLASSIFICATION_TASK not in prediction_type: |
|
return json.dumps(prediction), "json" |
|
if "top" in prediction: |
|
return prediction["top"], "txt" |
|
raise PredictionFormatNotSupported( |
|
f"Prediction type or prediction format not supported." |
|
) |
|
|