# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List, Optional

import numpy as np

from doctr.file_utils import requires_package
from doctr.utils.data import download_from_url


class _BasePredictor:
    """
    Base class for all predictors

    Args:
    ----
        batch_size: the batch size to use, must be a positive integer
        url: the url to use to download a model if needed
        model_path: the path to the model to use (takes precedence over `url` when both are given)
        **kwargs: additional arguments to be passed to `download_from_url`
    """

    def __init__(
        self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any
    ) -> None:
        # Fail early: a batch size of 0 would crash inside `range` at call time, and a negative one
        # would silently produce no batches at all.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.batch_size = batch_size
        self.session = self._init_model(url, model_path, **kwargs)

        self._inputs: List[np.ndarray] = []
        self._results: List[Any] = []

    def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
        """
        Download the model from the given url if needed and load it in an ONNX Runtime session

        Args:
        ----
            url: the url to use
            model_path: the path to the model to use (used instead of `url` when provided)
            **kwargs: additional arguments to be passed to `download_from_url`

        Returns:
        -------
            Any: the ONNX loaded model

        Raises:
        ------
            ValueError: if neither `url` nor `model_path` is provided
        """
        # Validate the arguments before touching the optional dependency
        if not url and not model_path:
            raise ValueError("You must provide either a url or a model_path")

        requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
        import onnxruntime as ort

        onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs))  # type: ignore[arg-type]

        # Prefer CUDA, fall back to CPU, but only request providers this onnxruntime build offers:
        # asking for CUDA on a CPU-only build makes onnxruntime warn before falling back.
        available_providers = ort.get_available_providers()
        providers = [
            provider
            for provider in ("CUDAExecutionProvider", "CPUExecutionProvider")
            if provider in available_providers
        ]
        return ort.InferenceSession(onnx_model_path, providers=providers)

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Preprocess the input image (to be implemented by subclasses)

        Args:
        ----
            img: the input image to preprocess

        Returns:
        -------
            np.ndarray: the preprocessed image
        """
        raise NotImplementedError

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
        """
        Postprocess the model output (to be implemented by subclasses)

        Args:
        ----
            output: the model output to postprocess, one entry per batch as returned by `session.run`
            input_images: the input images used to generate the output, grouped by batch

        Returns:
        -------
            Any: the postprocessed output
        """
        raise NotImplementedError

    def __call__(self, inputs: List[np.ndarray]) -> Any:
        """
        Call the model on the given inputs

        Args:
        ----
            inputs: the inputs to use

        Returns:
        -------
            Any: the postprocessed output
        """
        self._inputs = inputs
        model_inputs = self.session.get_inputs()

        # Split the inputs into consecutive chunks of at most `batch_size` elements
        batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)]
        processed_batches = [
            np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs
        ]

        # Only the first declared model input is fed; `None` requests every model output
        outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches]
        return self.postprocess(outputs, batched_inputs)