# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List, Optional

import numpy as np

from doctr.file_utils import requires_package
from doctr.utils.data import download_from_url


class _BasePredictor:
    """
    Base class for all predictors

    Args:
    ----
        batch_size: the batch size to use
        url: the url to use to download a model if needed
        model_path: the path to the model to use
        **kwargs: additional arguments to be passed to `download_from_url`
    """

    def __init__(
        self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any
    ) -> None:
        self.batch_size = batch_size
        self.session = self._init_model(url, model_path, **kwargs)

        self._inputs: List[np.ndarray] = []
        self._results: List[Any] = []

    def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
        """
        Download the model from the given url if no local model_path is provided, then load it as an ONNX session

        Args:
        ----
            url: the url to use
            model_path: the path to the model to use
            **kwargs: additional arguments to be passed to `download_from_url`

        Returns:
        -------
            Any: the loaded ONNX inference session
        """
        requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
        import onnxruntime as ort

        if not url and not model_path:
            raise ValueError("You must provide either a url or a model_path")
        onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs))  # type: ignore[arg-type]
        return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Preprocess the input image

        Args:
        ----
            img: the input image to preprocess

        Returns:
        -------
            np.ndarray: the preprocessed image
        """
        raise NotImplementedError

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
        """
        Postprocess the model output

        Args:
        ----
            output: the model output to postprocess
            input_images: the input images used to generate the output

        Returns:
        -------
            Any: the postprocessed output
        """
        raise NotImplementedError

    def __call__(self, inputs: List[np.ndarray]) -> Any:
        """
        Call the model on the given inputs

        Args:
        ----
            inputs: the input images to run the model on

        Returns:
        -------
            Any: the postprocessed output
        """
        self._inputs = inputs
        model_inputs = self.session.get_inputs()

        batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)]
        processed_batches = [
            np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs
        ]

        outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches]
        return self.postprocess(outputs, batched_inputs)
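
# Usage sketch (hypothetical): a minimal subclass wiring the two abstract hooks
# together. The class name, URL, preprocessing, and postprocessing below are
# illustrative assumptions for a single-label classifier, not part of doctr.
#
#   class MyClassifier(_BasePredictor):
#       def preprocess(self, img: np.ndarray) -> np.ndarray:
#           # Normalize to float32 in [0, 1] and move channels first (HWC -> CHW)
#           return (img.astype(np.float32) / 255.0).transpose(2, 0, 1)
#
#       def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
#           # Flatten the per-batch logits and take the argmax class per image
#           return [int(np.argmax(logits)) for batch in output for logits in batch[0]]
#
#   predictor = MyClassifier(batch_size=4, url="https://example.com/model.onnx")
#   classes = predictor([np.zeros((32, 32, 3), dtype=np.uint8)])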