File size: 6,921 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import cv2
import torch
from mmdet.structures import DetDataSample
from mmdet.structures import SampleList as MMDET_SampleList
from mmdet.structures.mask import bitmap_to_polygon
from mmengine.model import BaseModel
from mmengine.structures import InstanceData

from mmocr.registry import MODELS
from mmocr.utils.bbox_utils import bbox2poly
from mmocr.utils.typing_utils import DetSampleList

# Return type annotation of ``MMDetWrapper.forward``; the concrete type that
# is returned depends on the ``mode`` argument passed to it.
ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample],
                       Tuple[torch.Tensor], torch.Tensor]


@MODELS.register_module()
class MMDetWrapper(BaseModel):
    """A wrapper of MMDet's model.

    The wrapped model is built from ``cfg`` under the ``mmdet`` scope. In
    "predict" mode its outputs are converted into MMOCR's format (polygons
    and scores) by :meth:`adapt_predictions`.

    Args:
        cfg (dict): The config of the model. It must contain a
            ``data_preprocessor`` entry, which is handed to
            :class:`BaseModel`. The passed config is not modified.
        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
            Defaults to 'poly'.
    """

    def __init__(self, cfg: Dict, text_repr_type: str = 'poly') -> None:
        # Work on shallow copies so that the caller's config is left intact
        # and can be reused (e.g. to build the model a second time).
        cfg = cfg.copy()
        data_preprocessor = cfg.pop('data_preprocessor').copy()
        data_preprocessor['_scope_'] = 'mmdet'
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
        cfg['_scope_'] = 'mmdet'
        self.wrapped_model = MODELS.build(cfg)
        self.text_repr_type = text_repr_type

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[Union[DetSampleList,
                                             MMDET_SampleList]] = None,
                mode: str = 'tensor',
                **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method works in three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
        tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
        processed to a list of :obj:`TextDetDataSample`.
        - "loss": Forward and return a dict of losses according to the given
        inputs and data samples.

        Note that this method doesn't handle either back propagation or
        parameter update, which are supposed to be done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (list[:obj:`DetDataSample`] or
                list[:obj:`TextDetDataSample`]): The annotation data of every
                sample. When in "predict" mode, it should be a list of
                :obj:`TextDetDataSample`. Otherwise they are
                :obj:`DetDataSample`s. Defaults to None.
            mode (str): Running mode. Defaults to 'tensor'.
            **kwargs: Extra arguments forwarded unchanged to the wrapped
                model's ``forward``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`TextDetDataSample`.
            - If ``mode="loss"``, return a dict of tensor.
        """
        if mode != 'predict':
            return self.wrapped_model.forward(inputs, data_samples, mode,
                                              **kwargs)

        # The wrapped model gets fresh DetDataSamples that only carry the
        # meta information of the incoming samples; the incoming samples
        # themselves receive the converted predictions afterwards.
        det_data_samples = [
            DetDataSample(metainfo=ocr_data_sample.metainfo)
            for ocr_data_sample in data_samples
        ]
        results = self.wrapped_model.forward(inputs, det_data_samples, mode,
                                             **kwargs)
        return self.adapt_predictions(results, data_samples)

    def adapt_predictions(self, data: MMDET_SampleList,
                          data_samples: DetSampleList) -> DetSampleList:
        """Convert Instance datas from MMDet into MMOCR's format.

        If a prediction contains masks, every contour extracted from a mask
        becomes one polygon that inherits the score of its mask; contours
        with fewer than 3 points are dropped. With
        ``text_repr_type='quad'`` each polygon is replaced by the corners of
        its minimum-area bounding rectangle. Without masks, the predicted
        bboxes are converted to polygons.

        Args:
            data: (list[DetDataSample]): Detection results of the
                input images. Each DetDataSample usually contain
                'pred_instances'. And the ``pred_instances`` usually
                contains following keys.
                - scores (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                    the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor, Optional): Has a shape (num_instances, H, W).
            data_samples (list[:obj:`TextDetDataSample`]): The annotation data
                of every samples. They are updated in place.

        Returns:
            list[TextDetDataSample]: A list of N datasamples containing ground
                truth and prediction results.
                The polygon results are saved in
                ``TextDetDataSample.pred_instances.polygons``
                The confidence scores are saved in
                ``TextDetDataSample.pred_instances.scores``.
        """
        for i, det_data_sample in enumerate(data):
            det_instances = det_data_sample.pred_instances
            ocr_instances = InstanceData()

            if 'masks' in det_instances.keys():
                # convert mask to polygons if mask exists
                masks = det_instances.masks.cpu().numpy()
                # Move the scores to CPU once instead of once per contour.
                det_scores = det_instances.scores.cpu()
                polygons = []
                scores = []
                for mask_idx, mask in enumerate(masks):
                    contours, _ = bitmap_to_polygon(mask)
                    score = float(det_scores[mask_idx])
                    for contour in contours:
                        polygon = contour.reshape(-1)
                        # filter invalid polygons: at least 3 points
                        # (6 coordinates) are required
                        if len(polygon) < 6:
                            continue
                        if self.text_repr_type == 'quad':
                            # cv2.minAreaRect expects a point set of shape
                            # (N, 2) in int32/float32, not a flat
                            # coordinate vector.
                            points = polygon.reshape(-1, 2).astype('float32')
                            rect = cv2.minAreaRect(points)
                            polygon = cv2.boxPoints(rect).flatten()
                        polygons.append(polygon)
                        scores.append(score)

                ocr_instances.polygons = polygons
                ocr_instances.scores = torch.tensor(
                    scores, dtype=torch.float32)
            else:
                bboxes = det_instances.bboxes.cpu().numpy()
                ocr_instances.polygons = [bbox2poly(bbox) for bbox in bboxes]
                # ``.float()`` also accepts non-float32 (e.g. half precision)
                # scores, which the legacy ``torch.FloatTensor`` constructor
                # rejects.
                ocr_instances.scores = det_instances.scores.cpu().float()

            data_samples[i].pred_instances = ocr_instances

        return data_samples