from io import BytesIO
from time import perf_counter
from typing import Any, List, Tuple, Union

import numpy as np
from PIL import Image, ImageDraw, ImageFont

from inference.core.entities.requests.inference import ClassificationInferenceRequest
from inference.core.entities.responses.inference import (
    ClassificationInferenceResponse,
    InferenceResponse,
    InferenceResponseImage,
    MultiLabelClassificationInferenceResponse,
)
from inference.core.models.roboflow import OnnxRoboflowInferenceModel
from inference.core.models.types import PreprocessReturnMetadata
from inference.core.models.utils.validate import (
    get_num_classes_from_model_prediction_shape,
)
from inference.core.utils.image_utils import load_image_rgb


class ClassificationBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel):
    """Base class for ONNX models for Roboflow classification inference.

    Attributes:
        multiclass (bool): Whether the model performs multi-label ("multiclass") classification.

    Methods:
        get_infer_bucket_file_list() -> list: Get the list of required files for inference.
        softmax(x): Compute softmax values for a given set of scores.
        infer(image, ...): Run inference on an image or list of images.
        draw_predictions(inference_request, inference_response): Draw prediction visuals on an image.
    """

    task_type = "classification"

    def __init__(self, *args, **kwargs):
        """Initialize the model, setting whether it is multiclass or not."""
        super().__init__(*args, **kwargs)
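        # MULTICLASS is read from the model's environment.json (fetched per
        # get_infer_bucket_file_list); when true, multi-label responses are built.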
        self.multiclass = self.environment.get("MULTICLASS", False)

    def draw_predictions(self, inference_request, inference_response):
        """Draw prediction visuals on an image.

        This method overlays the predictions on the input image, including drawing rectangles and text to visualize the predicted classes.

        Args:
            inference_request: The request object containing the image and parameters.
            inference_response: The response object containing the predictions and other details.

        Returns:
            bytes: The bytes of the visualized image in JPEG format.
        """
        image = load_image_rgb(inference_request.image)
        image = Image.fromarray(image)
        draw = ImageDraw.Draw(image)
        font = ImageFont.load_default()
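        # Single-label responses carry a list of predictions; multi-label
        # responses carry a dict keyed by class name, so each is drawn differently.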
        if isinstance(inference_response.predictions, list):
            prediction = inference_response.predictions[0]
            color = self.colors.get(prediction.class_name, "#4892EA")
            draw.rectangle(
                [0, 0, image.size[0], image.size[1]],
                outline=color,
                width=inference_request.visualization_stroke_width,
            )
            text = f"{prediction.class_id} - {prediction.class_name} {prediction.confidence:.2f}"
            text_size = font.getbbox(text)

            button_size = (text_size[2] + 20, text_size[3] + 20)
            button_img = Image.new("RGBA", button_size, color)

            button_draw = ImageDraw.Draw(button_img)
            button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255))

            image.paste(button_img, (0, 0))
        else:
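            # Multi-label: draw one label chip per class, advancing `row` by each
            # chip's height so the labels stack without overlapping.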
            if len(inference_response.predictions) > 0:
                box_color = "#4892EA"
                draw.rectangle(
                    [0, 0, image.size[0], image.size[1]],
                    outline=box_color,
                    width=inference_request.visualization_stroke_width,
                )
                row = 0
                predictions = list(inference_response.predictions.items())
                predictions = sorted(
                    predictions, key=lambda x: x[1].confidence, reverse=True
                )
                for cls_name, pred in predictions:
                    color = self.colors.get(cls_name, "#4892EA")
                    text = f"{cls_name} {pred.confidence:.2f}"
                    text_size = font.getbbox(text)

                    button_size = (text_size[2] + 20, text_size[3] + 20)
                    button_img = Image.new("RGBA", button_size, color)

                    button_draw = ImageDraw.Draw(button_img)
                    button_draw.text(
                        (10, 10), text, font=font, fill=(255, 255, 255, 255)
                    )

                    image.paste(button_img, (0, row))
                    row += button_size[1]

        buffered = BytesIO()
        image = image.convert("RGB")
        image.save(buffered, format="JPEG")
        return buffered.getvalue()

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["environment.json"].
        """
        return ["environment.json"]

    def infer(
        self,
        image: Any,
        disable_preproc_auto_orient: bool = False,
        disable_preproc_contrast: bool = False,
        disable_preproc_grayscale: bool = False,
        disable_preproc_static_crop: bool = False,
        return_image_dims: bool = False,
        **kwargs,
    ):
        """
        Perform inference on the provided image(s) and return the predictions.

        Args:
            image (Any): The image or list of images to be processed.
            disable_preproc_auto_orient (bool, optional): If True, the auto-orient preprocessing step is disabled for this call. Defaults to False.
            disable_preproc_contrast (bool, optional): If True, the auto-contrast preprocessing step is disabled for this call. Defaults to False.
            disable_preproc_grayscale (bool, optional): If True, the grayscale preprocessing step is disabled for this call. Defaults to False.
            disable_preproc_static_crop (bool, optional): If True, the static crop preprocessing step is disabled for this call. Defaults to False.
            return_image_dims (bool, optional): If True, the function also returns the dimensions of the image(s). Defaults to False.
            **kwargs: Additional parameters to customize the inference process.

        Returns:
            Union[List[np.ndarray], np.ndarray, Tuple[List[np.ndarray], List[Tuple[int, int]]], Tuple[np.ndarray, Tuple[int, int]]]:
                If `return_image_dims` is True and a list of images is provided, a tuple containing a list of prediction arrays and a list of image dimensions is returned.
                If `return_image_dims` is True and a single image is provided, a tuple containing the prediction array and the image dimensions is returned.
                If `return_image_dims` is False and a list of images is provided, only the list of prediction arrays is returned.
                If `return_image_dims` is False and a single image is provided, only the prediction array is returned.

        Notes:
            - The input image(s) will be preprocessed (normalized and reshaped) before inference.
            - This function uses an ONNX session to perform inference on the input image(s).
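
        Example:
            Illustrative call only; the model id and image path are placeholders:

                model = ClassificationBaseOnnxRoboflowInferenceModel("some-project/1")
                predictions = model.infer("path/to/image.jpg")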
        """
        return super().infer(
            image,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
            disable_preproc_contrast=disable_preproc_contrast,
            disable_preproc_grayscale=disable_preproc_grayscale,
            disable_preproc_static_crop=disable_preproc_static_crop,
            return_image_dims=return_image_dims,
        )

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        return_image_dims=False,
        **kwargs,
    ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]:
        predictions = predictions[0]
        return self.make_response(
            predictions, preprocess_return_metadata["img_dims"], **kwargs
        )

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
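        # Run the ONNX graph; passing None fetches every declared output, fed
        # with the preprocessed NCHW batch under the graph's input name.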
        predictions = self.onnx_session.run(None, {self.input_name: img_in})
        return (predictions,)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
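        # A list input is preprocessed image-by-image and concatenated along the
        # batch axis; a single image keeps its dims wrapped in a one-element list
        # so downstream code sees a uniform shape.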
        if isinstance(image, list):
            imgs_with_dims = [
                self.preproc_image(
                    i,
                    disable_preproc_auto_orient=kwargs.get(
                        "disable_preproc_auto_orient", False
                    ),
                    disable_preproc_contrast=kwargs.get(
                        "disable_preproc_contrast", False
                    ),
                    disable_preproc_grayscale=kwargs.get(
                        "disable_preproc_grayscale", False
                    ),
                    disable_preproc_static_crop=kwargs.get(
                        "disable_preproc_static_crop", False
                    ),
                )
                for i in image
            ]
            imgs, img_dims = zip(*imgs_with_dims)
            img_in = np.concatenate(imgs, axis=0)
        else:
            img_in, img_dims = self.preproc_image(
                image,
                disable_preproc_auto_orient=kwargs.get(
                    "disable_preproc_auto_orient", False
                ),
                disable_preproc_contrast=kwargs.get("disable_preproc_contrast", False),
                disable_preproc_grayscale=kwargs.get(
                    "disable_preproc_grayscale", False
                ),
                disable_preproc_static_crop=kwargs.get(
                    "disable_preproc_static_crop", False
                ),
            )
            img_dims = [img_dims]

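        # Scale pixels to [0, 1], then normalize each channel with mean 0.5 and
        # std 0.5, mapping the tensor into the [-1, 1] range expected by the model.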
        img_in = img_in.astype(np.float32)
        img_in /= 255.0

        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

        img_in[:, 0, :, :] = (img_in[:, 0, :, :] - mean[0]) / std[0]
        img_in[:, 1, :, :] = (img_in[:, 1, :, :] - mean[1]) / std[1]
        img_in[:, 2, :, :] = (img_in[:, 2, :, :] - mean[2]) / std[2]
        return img_in, PreprocessReturnMetadata({"img_dims": img_dims})

    def infer_from_request(
        self,
        request: ClassificationInferenceRequest,
    ) -> Union[List[InferenceResponse], InferenceResponse]:
        """
        Handle an inference request to produce an appropriate response.

        Args:
            request (ClassificationInferenceRequest): The request object encapsulating the image(s) and relevant parameters.

        Returns:
            Union[List[InferenceResponse], InferenceResponse]: The response object(s) containing the predictions, visualization, and other pertinent details. If a list of images was provided, a list of responses is returned. Otherwise, a single response is returned.

        Notes:
            - Starts a timer at the beginning to calculate inference time.
            - Processes the image(s) through the `infer` method.
            - Generates the appropriate response object(s) using `make_response`.
            - Calculates and sets the time taken for inference.
            - If visualization is requested, the predictions are drawn on the image.
        """
        t1 = perf_counter()
        responses = self.infer(**request.dict(), return_image_dims=True)
        for response in responses:
            response.time = perf_counter() - t1

        if request.visualize_predictions:
            for response in responses:
                response.visualization = self.draw_predictions(request, response)

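        # make_response always returns a list; unwrap it when the request held
        # a single image so the response shape mirrors the input shape.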
        if not isinstance(request.image, list):
            responses = responses[0]

        return responses

    def make_response(
        self,
        predictions,
        img_dims,
        confidence: float = 0.5,
        **kwargs,
    ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]:
        """
        Create response objects for the given predictions and image dimensions.

        Args:
            predictions (list): List of prediction arrays from the inference process.
            img_dims (list): List of tuples indicating the dimensions (height, width) of each image.
            confidence (float, optional): Confidence threshold for filtering predictions. Defaults to 0.5.
            **kwargs: Additional parameters to influence the response creation process.

        Returns:
            Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: A response object or a list of response objects encapsulating the prediction details.

        Notes:
            - If the model is multiclass, a `MultiLabelClassificationInferenceResponse` is generated for each image.
            - If the model is not multiclass, a `ClassificationInferenceResponse` is generated for each image.
            - Predictions below the confidence threshold are filtered out.
        """
        responses = []
        confidence_threshold = float(confidence)
        for ind, prediction in enumerate(predictions):
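            # Multi-label: each raw score is an independent per-class confidence;
            # classes above the threshold are reported as predicted_classes.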
            if self.multiclass:
                preds = prediction[0]
                results = dict()
                predicted_classes = []
                for i, o in enumerate(preds):
                    cls_name = self.class_names[i]
                    score = float(o)
                    results[cls_name] = {"confidence": score, "class_id": i}
                    if score > confidence_threshold:
                        predicted_classes.append(cls_name)
                response = MultiLabelClassificationInferenceResponse(
                    image=InferenceResponseImage(
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                    predicted_classes=predicted_classes,
                    predictions=results,
                )
            else:
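                # Single-label: softmax turns the logits into a probability
                # distribution; the top-scoring class is reported as `top`.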
                preds = prediction[0]
                preds = self.softmax(preds)
                results = []
                for i, cls_name in enumerate(self.class_names):
                    score = float(preds[i])
                    pred = {
                        "class_id": i,
                        "class": cls_name,
                        "confidence": round(score, 4),
                    }
                    results.append(pred)
                results = sorted(results, key=lambda x: x["confidence"], reverse=True)

                response = ClassificationInferenceResponse(
                    image=InferenceResponseImage(
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                    predictions=results,
                    top=results[0]["class"],
                    confidence=results[0]["confidence"],
                )
            responses.append(response)

        return responses

    @staticmethod
    def softmax(x):
        """Compute softmax values for each set of scores in x.

        Args:
            x (np.ndarray): The input array containing the scores.

        Returns:
            np.ndarray: The softmax values for each set of scores.
        """
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum()

    def get_model_output_shape(self) -> Tuple[int, ...]:
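        # Probe the ONNX graph with a random dummy image; preprocess resizes it
        # to the model's expected input, so no real data is needed to discover
        # the output shape.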
        test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
        test_image, _ = self.preprocess(test_image)
        output = np.array(self.predict(test_image))
        return output.shape

    def validate_model_classes(self) -> None:
        output_shape = self.get_model_output_shape()
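        # get_model_output_shape() wraps predict()'s tuple in np.array, giving a
        # shape of (1, n_outputs, batch, n_classes); the last axis is the class count.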
        num_classes = output_shape[3]
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )