# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence

import cv2
import numpy as np
import torch
from mmcv.ops import pixel_group
from mmengine.structures import InstanceData

from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .base import BaseTextDetPostProcessor


@MODELS.register_module()
class PANPostprocessor(BaseTextDetPostProcessor):
    """Convert scores to quadrangles via post processing in PANet. This is
    partially adapted from https://github.com/WenmuZhou/PAN.pytorch.

    Args:
        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
            Defaults to 'poly'.
        score_threshold (float): The minimal text score.
            Defaults to 0.3.
        rescale_fields (list[str]): The bbox/polygon field names to
            be rescaled. If None, no rescaling will be performed. Defaults to
            ['polygons'].
        min_text_confidence (float): The minimal text confidence.
            Defaults to 0.5.
        min_kernel_confidence (float): The minimal kernel confidence.
            Defaults to 0.5.
        distance_threshold (float): The maximum allowed distance between a
            pixel embedding and the mean embedding of a text kernel for the
            pixel to be grouped into that kernel. Defaults to 3.0.
        min_text_area (int): The minimal text instance region area.
            Defaults to 16.
        downsample_ratio (float): Downsample ratio. Defaults to 0.25.
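
    Example:
        >>> # A minimal usage sketch, not from the original docstring. The
        >>> # 6-channel input (text score, kernel score and a 4-d embedding
        >>> # map) and the ``scale_factor`` metainfo key are assumptions
        >>> # inferred from ``get_text_instances`` below.
        >>> import torch
        >>> from mmocr.structures import TextDetDataSample
        >>> postprocessor = PANPostprocessor()
        >>> preds = torch.rand(6, 160, 160)
        >>> sample = TextDetDataSample(
        ...     metainfo=dict(scale_factor=(0.5, 0.5)))
        >>> sample = postprocessor.get_text_instances(preds, sample)
        >>> polygons = sample.pred_instances.polygons
        >>> scores = sample.pred_instances.scores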
    """

    def __init__(self,
                 text_repr_type: str = 'poly',
                 score_threshold: float = 0.3,
                 rescale_fields: Sequence[str] = ['polygons'],
                 min_text_confidence: float = 0.5,
                 min_kernel_confidence: float = 0.5,
                 distance_threshold: float = 3.0,
                 min_text_area: int = 16,
                 downsample_ratio: float = 0.25) -> None:
        super().__init__(text_repr_type, rescale_fields)

        self.min_text_confidence = min_text_confidence
        self.min_kernel_confidence = min_kernel_confidence
        self.score_threshold = score_threshold
        self.min_text_area = min_text_area
        self.distance_threshold = distance_threshold
        self.downsample_ratio = downsample_ratio

    def get_text_instances(self, pred_results: torch.Tensor,
                           data_sample: TextDetDataSample,
                           **kwargs) -> TextDetDataSample:
        """Get text instance predictions of one image.

        Args:
            pred_results (torch.Tensor): Prediction results of an image,
                which is a tensor of shape :math:`(N, H, W)`.
            data_sample (TextDetDataSample): Datasample of an image.

        Returns:
            TextDetDataSample: A new DataSample with predictions filled in.
            The predicted polygons are saved in
            ``TextDetDataSample.pred_instances.polygons`` and the confidence
            scores are saved in ``TextDetDataSample.pred_instances.scores``.
        """
        assert pred_results.dim() == 3

        pred_results[:2, :, :] = torch.sigmoid(pred_results[:2, :, :])
        pred_results = pred_results.detach().cpu().numpy()

        text_score = pred_results[0].astype(np.float32)
        text = pred_results[0] > self.min_text_confidence
        kernel = (pred_results[1] > self.min_kernel_confidence) * text
        embeddings = pred_results[2:] * text.astype(np.float32)
        embeddings = embeddings.transpose((1, 2, 0))  # (h, w, 4)

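        # Label connected kernel regions, draw their contours, and let
        # ``pixel_group`` assign text pixels to kernels by embedding
        # distance.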
        region_num, labels = cv2.connectedComponents(
            kernel.astype(np.uint8), connectivity=4)
        contours, _ = cv2.findContours((kernel * 255).astype(np.uint8),
                                       cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
        kernel_contours = np.zeros(text.shape, dtype='uint8')
        cv2.drawContours(kernel_contours, contours, -1, 255)
        text_points = pixel_group(text_score, text, embeddings, labels,
                                  kernel_contours, region_num,
                                  self.distance_threshold)

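        # Keep only groups that are large and confident enough, then convert
        # each surviving pixel set into a polygon or quadrangle boundary.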
        polygons = []
        scores = []
        for text_point in text_points:
            text_confidence = text_point[0]
            text_point = text_point[2:]
            text_point = np.array(text_point, dtype=int).reshape(-1, 2)
            area = text_point.shape[0]
            if (area < self.min_text_area
                    or text_confidence <= self.score_threshold):
                continue

            polygon = self._points2boundary(text_point)
            if len(polygon) > 0:
                polygons.append(polygon)
                scores.append(text_confidence)
        pred_instances = InstanceData()
        pred_instances.polygons = polygons
        pred_instances.scores = torch.FloatTensor(scores)
        data_sample.pred_instances = pred_instances
        scale_factor = data_sample.scale_factor
        scale_factor = tuple(factor * self.downsample_ratio
                             for factor in scale_factor)
        data_sample.set_metainfo(dict(scale_factor=scale_factor))
        return data_sample

    def _points2boundary(self,
                         points: np.ndarray,
                         min_width: int = 0) -> List[float]:
        """Convert a text mask represented by point coordinates sequence into a
        text boundary.

        Args:
            points (ndarray): The foreground pixel coordinates of the text
                mask, of shape (n, 2).
            min_width (int): Minimum bounding box width to be converted. Only
                applicable to 'quad' type. Defaults to 0.

        Returns:
            list[float]: The text boundary coordinates, flattened as
            [x1, y1, x2, y2, ...]. Returns [] if no valid boundary is found.
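
        Example:
            >>> # Illustrative sketch, not part of the original docstring:
            >>> # with ``text_repr_type='quad'``, a small pixel block is
            >>> # converted to its minimum-area rectangle and returned as
            >>> # 8 flat coordinates (or [] if degenerate).
            >>> import numpy as np
            >>> pp = PANPostprocessor(text_repr_type='quad')
            >>> pts = np.array([[x, y] for x in range(4) for y in range(4)],
            ...                dtype=np.float32)
            >>> quad = pp._points2boundary(pts)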
        """
        assert isinstance(points, np.ndarray)
        assert points.shape[1] == 2
        assert self.text_repr_type in ['quad', 'poly']

        if self.text_repr_type == 'quad':
            rect = cv2.minAreaRect(points)
            vertices = cv2.boxPoints(rect)
            boundary = []
            if min(rect[1]) >= min_width:
                boundary = vertices.flatten().tolist()
        elif self.text_repr_type == 'poly':

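            # Rasterize the points into a padded binary mask and take its
            # external contour as the polygon boundary.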
            height = np.max(points[:, 1]) + 10
            width = np.max(points[:, 0]) + 10

            mask = np.zeros((height, width), np.uint8)
            mask[points[:, 1], points[:, 0]] = 255

            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)
            boundary = contours[0].flatten().tolist()

        if len(boundary) < 8:
            return []

        return boundary