File size: 9,318 Bytes
153628e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any

from .detection.zoo import detection_predictor
from .kie_predictor import KIEPredictor
from .predictor import OCRPredictor
from .recognition.zoo import recognition_predictor

__all__ = ["ocr_predictor", "kie_predictor"]


def _predictor(
    det_arch: Any,
    reco_arch: Any,
    pretrained: bool,
    pretrained_backbone: bool = True,
    assume_straight_pages: bool = True,
    preserve_aspect_ratio: bool = True,
    symmetric_pad: bool = True,
    det_bs: int = 2,
    reco_bs: int = 128,
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    **kwargs: Any,
) -> OCRPredictor:
    """Assemble an `OCRPredictor` from a detection and a recognition architecture.

    Args:
    ----
        det_arch: detection architecture name or model, forwarded to `detection_predictor`
        reco_arch: recognition architecture name or model, forwarded to `recognition_predictor`
        pretrained: forwarded to both `detection_predictor` and `recognition_predictor`
        pretrained_backbone: forwarded to both `detection_predictor` and `recognition_predictor`
        assume_straight_pages: forwarded to `detection_predictor` and to `OCRPredictor`
        preserve_aspect_ratio: forwarded to `detection_predictor` and to `OCRPredictor`
        symmetric_pad: forwarded to `detection_predictor` and to `OCRPredictor`
        det_bs: batch size given to `detection_predictor`
        reco_bs: batch size given to `recognition_predictor`
        detect_orientation: forwarded to `OCRPredictor`
        straighten_pages: forwarded to `OCRPredictor`
        detect_language: forwarded to `OCRPredictor`
        kwargs: any remaining keyword arguments, forwarded unchanged to `OCRPredictor`

    Returns:
    -------
        the assembled OCR predictor
    """
    # Weight-loading switches apply identically to both sub-models.
    weights = {"pretrained": pretrained, "pretrained_backbone": pretrained_backbone}
    # Page-geometry switches are shared by the detector and the end-to-end predictor.
    geometry = {
        "assume_straight_pages": assume_straight_pages,
        "preserve_aspect_ratio": preserve_aspect_ratio,
        "symmetric_pad": symmetric_pad,
    }

    text_detector = detection_predictor(det_arch, batch_size=det_bs, **weights, **geometry)
    text_recognizer = recognition_predictor(reco_arch, batch_size=reco_bs, **weights)

    return OCRPredictor(
        text_detector,
        text_recognizer,
        detect_orientation=detect_orientation,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        **geometry,
        **kwargs,
    )


def ocr_predictor(
    det_arch: Any = "fast_base",
    reco_arch: Any = "crnn_vgg16_bn",
    pretrained: bool = False,
    pretrained_backbone: bool = True,
    assume_straight_pages: bool = True,
    preserve_aspect_ratio: bool = True,
    symmetric_pad: bool = True,
    export_as_straight_boxes: bool = False,
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    **kwargs: Any,
) -> OCRPredictor:
    """Build an end-to-end OCR predictor: one model localizes text, a second one reads it.

    >>> import numpy as np
    >>> from doctr.models import ocr_predictor
    >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True)
    >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
    >>> out = model([input_page])

    Args:
    ----
        det_arch: detection architecture, given either by name (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
            or as a model instance
        reco_arch: recognition architecture, given either by name (e.g. 'crnn_vgg16_bn', 'sar_resnet31')
            or as a model instance
        pretrained: if True, load weights pre-trained on our OCR dataset
        pretrained_backbone: if True, load a pretrained backbone
        assume_straight_pages: if True, inference is faster because pages are assumed to be straight,
            with no rotated text elements.
        preserve_aspect_ratio: if True, the input document image is padded so that its aspect ratio
            is kept before the detection model runs on it.
        symmetric_pad: if True, padding is applied symmetrically rather than only at the bottom-right.
        export_as_straight_boxes: when assume_straight_pages is False, export the final (possibly rotated)
            predictions as straight bounding boxes.
        detect_orientation: if True, each page's estimated overall orientation is added to the predictions,
            at a slight cost in latency.
        straighten_pages: if True, the page's overall orientation is estimated from the median line
            orientation of the segmentation map, and the page is rotated before being passed again to
            the detection model. This helps on documents with a uniform rotation across the page.
        detect_language: if True, each page's predicted language is added to the predictions,
            at a slight cost in latency.
        kwargs: additional keyword arguments. `det_bs` and `reco_bs` set the detection and recognition
            batch sizes; everything else goes to `OCRPredictor`.

    Returns:
    -------
        OCR predictor
    """
    return _predictor(
        det_arch=det_arch,
        reco_arch=reco_arch,
        pretrained=pretrained,
        pretrained_backbone=pretrained_backbone,
        assume_straight_pages=assume_straight_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
        export_as_straight_boxes=export_as_straight_boxes,
        detect_orientation=detect_orientation,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        **kwargs,
    )


def _kie_predictor(
    det_arch: Any,
    reco_arch: Any,
    pretrained: bool,
    pretrained_backbone: bool = True,
    assume_straight_pages: bool = True,
    preserve_aspect_ratio: bool = True,
    symmetric_pad: bool = True,
    det_bs: int = 2,
    reco_bs: int = 128,
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    **kwargs: Any,
) -> KIEPredictor:
    """Assemble a `KIEPredictor` from a detection and a recognition architecture.

    Args:
    ----
        det_arch: detection architecture name or model, forwarded to `detection_predictor`
        reco_arch: recognition architecture name or model, forwarded to `recognition_predictor`
        pretrained: forwarded to both `detection_predictor` and `recognition_predictor`
        pretrained_backbone: forwarded to both `detection_predictor` and `recognition_predictor`
        assume_straight_pages: forwarded to `detection_predictor` and to `KIEPredictor`
        preserve_aspect_ratio: forwarded to `detection_predictor` and to `KIEPredictor`
        symmetric_pad: forwarded to `detection_predictor` and to `KIEPredictor`
        det_bs: batch size given to `detection_predictor`
        reco_bs: batch size given to `recognition_predictor`
        detect_orientation: forwarded to `KIEPredictor`
        straighten_pages: forwarded to `KIEPredictor`
        detect_language: forwarded to `KIEPredictor`
        kwargs: any remaining keyword arguments, forwarded unchanged to `KIEPredictor`

    Returns:
    -------
        the assembled KIE predictor
    """
    # Weight-loading switches apply identically to both sub-models.
    weights = {"pretrained": pretrained, "pretrained_backbone": pretrained_backbone}
    # Page-geometry switches are shared by the detector and the end-to-end predictor.
    geometry = {
        "assume_straight_pages": assume_straight_pages,
        "preserve_aspect_ratio": preserve_aspect_ratio,
        "symmetric_pad": symmetric_pad,
    }

    text_detector = detection_predictor(det_arch, batch_size=det_bs, **weights, **geometry)
    text_recognizer = recognition_predictor(reco_arch, batch_size=reco_bs, **weights)

    return KIEPredictor(
        text_detector,
        text_recognizer,
        detect_orientation=detect_orientation,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        **geometry,
        **kwargs,
    )


def kie_predictor(
    det_arch: Any = "fast_base",
    reco_arch: Any = "crnn_vgg16_bn",
    pretrained: bool = False,
    pretrained_backbone: bool = True,
    assume_straight_pages: bool = True,
    preserve_aspect_ratio: bool = True,
    symmetric_pad: bool = True,
    export_as_straight_boxes: bool = False,
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    **kwargs: Any,
) -> KIEPredictor:
    """End-to-end KIE architecture using one model for localization, and another for text recognition.

    >>> import numpy as np
    >>> from doctr.models import kie_predictor
    >>> model = kie_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True)
    >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
    >>> out = model([input_page])

    Args:
    ----
        det_arch: name of the detection architecture or the model itself to use
            (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
        reco_arch: name of the recognition architecture or the model itself to use
            (e.g. 'crnn_vgg16_bn', 'sar_resnet31')
        pretrained: If True, returns a model pre-trained on our OCR dataset
        pretrained_backbone: If True, returns a model with a pretrained backbone
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
            running the detection model on it.
        symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right.
        export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions
            (potentially rotated) as straight bounding boxes.
        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        straighten_pages: if True, estimates the page general orientation
            based on the segmentation map median line orientation.
            Then, rotates page before passing it again to the deep learning detection module.
            Doing so will improve performances for documents with page-uniform rotations.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        kwargs: additional keyword arguments. `det_bs` and `reco_bs` set the detection and recognition
            batch sizes; everything else is forwarded to `KIEPredictor`.

    Returns:
    -------
        KIE predictor
    """
    return _kie_predictor(
        det_arch,
        reco_arch,
        pretrained,
        pretrained_backbone=pretrained_backbone,
        assume_straight_pages=assume_straight_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
        export_as_straight_boxes=export_as_straight_boxes,
        detect_orientation=detect_orientation,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        **kwargs,
    )