Spaces:
Runtime error
Runtime error
# Copyright (C) 2021-2024, Mindee. | |
# This program is licensed under the Apache License 2.0. | |
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
from typing import Any, List, Optional | |
import numpy as np | |
from doctr.file_utils import requires_package | |
from doctr.utils.data import download_from_url | |
class _BasePredictor: | |
""" | |
Base class for all predictors | |
Args: | |
---- | |
batch_size: the batch size to use | |
url: the url to use to download a model if needed | |
model_path: the path to the model to use | |
**kwargs: additional arguments to be passed to `download_from_url` | |
""" | |
def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None: | |
self.batch_size = batch_size | |
self.session = self._init_model(url, model_path, **kwargs) | |
self._inputs: List[np.ndarray] = [] | |
self._results: List[Any] = [] | |
def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any: | |
""" | |
Download the model from the given url if needed | |
Args: | |
---- | |
url: the url to use | |
model_path: the path to the model to use | |
**kwargs: additional arguments to be passed to `download_from_url` | |
Returns: | |
------- | |
Any: the ONNX loaded model | |
""" | |
requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.") | |
import onnxruntime as ort | |
if not url and not model_path: | |
raise ValueError("You must provide either a url or a model_path") | |
onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs)) # type: ignore[arg-type] | |
return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) | |
def preprocess(self, img: np.ndarray) -> np.ndarray: | |
""" | |
Preprocess the input image | |
Args: | |
---- | |
img: the input image to preprocess | |
Returns: | |
------- | |
np.ndarray: the preprocessed image | |
""" | |
raise NotImplementedError | |
def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any: | |
""" | |
Postprocess the model output | |
Args: | |
---- | |
output: the model output to postprocess | |
input_images: the input images used to generate the output | |
Returns: | |
------- | |
Any: the postprocessed output | |
""" | |
raise NotImplementedError | |
def __call__(self, inputs: List[np.ndarray]) -> Any: | |
""" | |
Call the model on the given inputs | |
Args: | |
---- | |
inputs: the inputs to use | |
Returns: | |
------- | |
Any: the postprocessed output | |
""" | |
self._inputs = inputs | |
model_inputs = self.session.get_inputs() | |
batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)] | |
processed_batches = [ | |
np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs | |
] | |
outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches] | |
return self.postprocess(outputs, batched_inputs) | |