File size: 4,251 Bytes
14c9181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.textdet.postprocessors.base import BaseTextDetPostProcessor
from mmocr.registry import MODELS
from ..utils import bezier2poly


@MODELS.register_module()
class ABCNetPostprocessor(BaseTextDetPostProcessor):
    """Post-processing methods for ABCNet.

    Args:
        num_classes (int): Number of classes.
        use_sigmoid_cls (bool): Whether to use sigmoid for classification.
        strides (tuple): Strides of each feature map.
        norm_by_strides (bool): Whether to normalize the regression targets by
            the strides.
        bbox_coder (dict): Config dict for bbox coder.
        text_repr_type (str): Text representation type, 'poly' or 'quad'.
        with_bezier (bool): Whether to use bezier curve for text detection.
        train_cfg (dict): Config dict for training.
        test_cfg (dict): Config dict for testing.
    """

    def __init__(
        self,
        text_repr_type='poly',
        rescale_fields=['beziers', 'polygons'],
    ):
        super().__init__(
            text_repr_type=text_repr_type, rescale_fields=rescale_fields)

    def merge_predict(self, spotting_data_samples, recog_data_samples):
        texts = [ds.pred_text.item for ds in recog_data_samples]
        start = 0
        for spotting_data_sample in spotting_data_samples:
            end = start + len(spotting_data_sample.pred_instances)
            spotting_data_sample.pred_instances.texts = texts[start:end]
            start = end
        return spotting_data_samples

    # TODO: fix docstr
    def __call__(self,
                 spotting_data_samples,
                 recog_data_samples,
                 training: bool = False):
        """Postprocess pred_results according to metainfos in data_samples.

        Args:
            pred_results (Union[Tensor, List[Tensor]]): The prediction results
                stored in a tensor or a list of tensor. Usually each item to
                be post-processed is expected to be a batched tensor.
            data_samples (list[TextDetDataSample]): Batch of data_samples,
                each corresponding to a prediction result.
            training (bool): Whether the model is in training mode. Defaults to
                False.

        Returns:
            list[TextDetDataSample]: Batch of post-processed datasamples.
        """
        spotting_data_samples = list(
            map(self._process_single, spotting_data_samples))
        return self.merge_predict(spotting_data_samples, recog_data_samples)

    def _process_single(self, data_sample):
        """Process prediction results from one image.

        Args:
            pred_result (Union[Tensor, List[Tensor]]): Prediction results of an
                image.
            data_sample (TextDetDataSample): Datasample of an image.
        """
        data_sample = self.get_text_instances(data_sample)
        if self.rescale_fields and len(self.rescale_fields) > 0:
            assert isinstance(self.rescale_fields, list)
            assert set(self.rescale_fields).issubset(
                set(data_sample.pred_instances.keys()))
            data_sample = self.rescale(data_sample, data_sample.scale_factor)
        return data_sample

    def get_text_instances(self, data_sample, **kwargs):
        """Get text instance predictions of one image.

        Args:
            pred_result (tuple(Tensor)): Prediction results of an image.
            data_sample (TextDetDataSample): Datasample of an image.
            **kwargs: Other parameters. Configurable via ``__init__.train_cfg``
                and ``__init__.test_cfg``.

        Returns:
            TextDetDataSample: A new DataSample with predictions filled in.
            The polygon/bbox results are usually saved in
            ``TextDetDataSample.pred_instances.polygons`` or
            ``TextDetDataSample.pred_instances.bboxes``. The confidence scores
            are saved in ``TextDetDataSample.pred_instances.scores``.
        """
        data_sample = data_sample.cpu().numpy()
        pred_instances = data_sample.pred_instances
        data_sample.pred_instances.polygons = list(
            map(bezier2poly, pred_instances.beziers))
        return data_sample