File size: 7,355 Bytes
153628e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Dict, List, Union

import numpy as np
import tensorflow as tf

from doctr.io.elements import Document
from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.utils.geometry import rotate_image
from doctr.utils.repr import NestedObject

from .base import _KIEPredictor

__all__ = ["KIEPredictor"]


class KIEPredictor(NestedObject, _KIEPredictor):
    """Implements an object able to localize and identify text elements in a set of documents

    Args:
    ----
        det_predictor: detection module
        reco_predictor: recognition module
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the page general orientation based on the median line orientation.
            Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
            accordingly. Doing so will improve performances for documents with page-uniform rotations.
        preserve_aspect_ratio: forwarded unchanged to the `_KIEPredictor` initializer (see the base class for its
            exact effect).
        symmetric_pad: forwarded unchanged to the `_KIEPredictor` initializer (see the base class for its
            exact effect).
        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        **kwargs: keyword args of `DocumentBuilder`
    """

    _children_names = ["det_predictor", "reco_predictor", "doc_builder"]

    def __init__(
        self,
        det_predictor: DetectionPredictor,
        reco_predictor: RecognitionPredictor,
        assume_straight_pages: bool = True,
        straighten_pages: bool = False,
        preserve_aspect_ratio: bool = True,
        symmetric_pad: bool = True,
        detect_orientation: bool = False,
        detect_language: bool = False,
        **kwargs: Any,
    ) -> None:
        self.det_predictor = det_predictor
        self.reco_predictor = reco_predictor
        _KIEPredictor.__init__(
            self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
        )
        self.detect_orientation = detect_orientation
        self.detect_language = detect_language

    def __call__(
        self,
        pages: List[Union[np.ndarray, tf.Tensor]],
        **kwargs: Any,
    ) -> Document:
        """Run detection then recognition on each page and assemble the result with `self.doc_builder`.

        Args:
        ----
            pages: list of page images, each of them being a 3-dimensional array / tensor
            **kwargs: forwarded to both the detection and the recognition predictors. The optional `bin_thresh`
                entry (default 0.3) is also used here to binarize the detection maps before orientation estimation.

        Returns:
        -------
            the output of `self.doc_builder` for the given pages

        Raises:
        ------
            ValueError: if any page is not 3-dimensional
        """
        # Dimension check
        if any(page.ndim != 3 for page in pages):
            raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

        origin_page_shapes = [page.shape[:2] for page in pages]

        # Localize text elements
        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

        # Detect document rotation and rotate pages.
        # The binarized maps are only consumed by the orientation estimation, so they are built only when
        # one of the two orientation-dependent options is enabled.
        orientations = None
        if self.detect_orientation or self.straighten_pages:
            bin_thresh = kwargs.get("bin_thresh", 0.3)
            seg_maps = [
                np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > bin_thresh, 255, 0).astype(np.uint8)
                for out_map in out_maps
            ]
            # Estimated once, then shared by the reporting and the straightening steps
            origin_page_orientations = [estimate_orientation(seg_map) for seg_map in seg_maps]
            if self.detect_orientation:
                orientations = [
                    {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
                ]
            if self.straighten_pages:
                pages = [
                    rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)
                ]
                # Forward again to get predictions on straight pages
                loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]

        dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore
        # Rectify crops if aspect ratio
        dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}

        # Apply hooks to loc_preds if any
        for hook in self.hooks:
            dict_loc_preds = hook(dict_loc_preds)

        # Crop images (values of existing keys are reassigned; the key set itself is left untouched)
        crops = {}
        for class_name in dict_loc_preds:
            crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
                pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
            )

        # Rectify crop orientation
        crop_orientations: Any = {}
        if not self.assume_straight_pages:
            for class_name in dict_loc_preds:
                crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
                    crops[class_name], dict_loc_preds[class_name]
                )
                crop_orientations[class_name] = [
                    {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
                ]

        # Identify character sequences (crops of all pages are flattened per class before recognition)
        word_preds = {
            k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
            for k, crop_value in crops.items()
        }
        # No rectification result available: default every recognized word to orientation 0, unknown confidence
        if not crop_orientations:
            crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}

        boxes: Dict = {}
        text_preds: Dict = {}
        word_crop_orientations: Dict = {}
        for class_name in dict_loc_preds:
            boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
                dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
            )

        # Switch from a per-class layout to a per-page layout
        boxes_per_page: List[Dict] = invert_data_structure(boxes)  # type: ignore[assignment]
        text_preds_per_page: List[Dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
        crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]

        if self.detect_language:
            languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
            languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
        else:
            languages_dict = None

        out = self.doc_builder(
            pages,
            boxes_per_page,
            text_preds_per_page,
            origin_page_shapes,  # type: ignore[arg-type]
            crop_orientations_per_page,
            orientations,
            languages_dict,
        )
        return out

    @staticmethod
    def get_text(text_pred: Dict) -> str:
        """Concatenate the first element of every prediction of a page into a single space-separated string.

        Args:
        ----
            text_pred: mapping whose values are sequences of indexable items; the element at index 0 of each item
                is used as the text

        Returns:
        -------
            the texts joined with a single space, following the iteration order of the mapping values
        """
        return " ".join(item[0] for value in text_pred.values() for item in value)