# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

import mmcv
import numpy as np
import torch

from mmocr.registry import VISUALIZERS
from mmocr.structures import TextDetDataSample
from mmocr.utils.polygon_utils import poly2bbox
from .base_visualizer import BaseLocalVisualizer


@VISUALIZERS.register_module()
class TextSpottingLocalVisualizer(BaseLocalVisualizer):
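    """Visualizer for end-to-end text spotting results.

    Detected regions (polygons or bounding boxes) are drawn on the input
    image, and the recognized texts are rendered on a white canvas stitched
    to the right of it.
    """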

    def _draw_instances(
        self,
        image: np.ndarray,
        bboxes: Union[np.ndarray, torch.Tensor],
        polygons: Sequence[np.ndarray],
        texts: Sequence[str],
    ) -> np.ndarray:
        """Draw instances on image.

        Args:
            image (np.ndarray): The original image to draw on. The format
                should be RGB.
            bboxes (np.ndarray or torch.Tensor): The bboxes to draw. The
                shape of bboxes should be (N, 4), where N is the number of
                texts.
            polygons (Sequence[np.ndarray]): The polygons to draw. The length
                of polygons should be the same as the number of bboxes.
            texts (Sequence[str]): The texts to draw. The length of texts
                should be the same as the number of bboxes.

        Returns:
            np.ndarray: The detection image and the rendered-text canvas
            concatenated side by side.
        """
        img_shape = image.shape[:2]
        empty_shape = (img_shape[0], img_shape[1], 3)
        text_image = np.full(empty_shape, 255, dtype=np.uint8)
        if texts:
            text_image = self.get_labels_image(
                text_image,
                labels=texts,
                bboxes=bboxes,
                font_families=self.font_families,
                font_properties=self.font_properties)
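        # Prefer polygons when available: fill them on the image and outline
        # them on the text canvas; otherwise fall back to bounding boxes.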
        if polygons:
            polygons = [polygon.reshape(-1, 2) for polygon in polygons]
            image = self.get_polygons_image(
                image, polygons, filling=True, colors=self.PALETTE)
            text_image = self.get_polygons_image(
                text_image, polygons, colors=self.PALETTE)
        elif len(bboxes) > 0:
            image = self.get_bboxes_image(
                image, bboxes, filling=True, colors=self.PALETTE)
            text_image = self.get_bboxes_image(
                text_image, bboxes, colors=self.PALETTE)
        return np.concatenate([image, text_image], axis=1)

    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: Optional['TextDetDataSample'] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: int = 0,
                       pred_score_thr: float = 0.5,
                       out_file: Optional[str] = None,
                       step: int = 0) -> np.ndarray:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
        displayed in a stitched image where the left image is the
        ground truth and the right image is the prediction.
        - If ``show`` is True, all storage backends are ignored, and
        the images will be displayed in a local window.
        - If ``out_file`` is specified, the drawn image will be
        saved to ``out_file``. This is usually used when the display
        is not available.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`TextDetDataSample`, optional):
                TextDetDataSample which contains gt and prediction. Defaults
                to None.
            draw_gt (bool): Whether to draw GT TextDetDataSample.
                Defaults to True.
            draw_pred (bool): Whether to draw Predicted TextDetDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image. Defaults to
                False.
            wait_time (int): The interval of show (s). Defaults to 0.
            pred_score_thr (float): The threshold to visualize the bboxes
                and masks. Defaults to 0.5.
            out_file (str): Path to output file. Defaults to None.
            step (int): Global step value to record. Defaults to 0.

        Returns:
            np.ndarray: The drawn image.
        """
        cat_images = []

        if data_sample is not None:
            if draw_gt and 'gt_instances' in data_sample:
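                # Draw ground-truth boxes/polygons and their transcriptions.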
                gt_bboxes = data_sample.gt_instances.get('bboxes', None)
                gt_texts = data_sample.gt_instances.texts
                gt_polygons = data_sample.gt_instances.get('polygons', None)
                gt_img_data = self._draw_instances(image, gt_bboxes,
                                                   gt_polygons, gt_texts)
                cat_images.append(gt_img_data)

            if draw_pred and 'pred_instances' in data_sample:
                pred_instances = data_sample.pred_instances
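                # Keep only predictions above the score threshold and move
                # them to CPU numpy before drawing.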
                pred_instances = pred_instances[
                    pred_instances.scores > pred_score_thr].cpu().numpy()
                pred_bboxes = pred_instances.get('bboxes', None)
                pred_texts = pred_instances.texts
                pred_polygons = pred_instances.get('polygons', None)
                if pred_bboxes is None:
                    pred_bboxes = [poly2bbox(poly) for poly in pred_polygons]
                    pred_bboxes = np.array(pred_bboxes)
                pred_img_data = self._draw_instances(image, pred_bboxes,
                                                     pred_polygons, pred_texts)
                cat_images.append(pred_img_data)

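        # Stack the GT (top) and prediction (bottom) visualizations
        # vertically; fall back to the raw image if nothing was drawn.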
        cat_images = self._cat_image(cat_images, axis=0)
        if cat_images is None:
            cat_images = image

        if show:
            self.show(cat_images, win_name=name, wait_time=wait_time)
        else:
            self.add_image(name, cat_images, step)

        if out_file is not None:
            mmcv.imwrite(cat_images[..., ::-1], out_file)

        self.set_image(cat_images)
        return self.get_image()
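

# Minimal usage sketch (illustrative only, not part of the original module).
# It builds a dummy white image and a hypothetical ``TextDetDataSample`` with
# one fake prediction; the field names follow what ``add_datasample`` reads
# above (``polygons``, ``scores``, ``texts``), the constructor arguments rely
# on ``BaseLocalVisualizer`` defaults, and the output path is a placeholder.
if __name__ == '__main__':
    from mmengine.structures import InstanceData

    visualizer = TextSpottingLocalVisualizer(name='visualizer')
    demo_image = np.full((100, 200, 3), 255, dtype=np.uint8)  # dummy RGB image

    demo_sample = TextDetDataSample()
    demo_sample.pred_instances = InstanceData(
        polygons=[
            np.array([20, 20, 180, 20, 180, 60, 20, 60], dtype=np.float32)
        ],
        scores=torch.tensor([0.9]),
        texts=['demo'])

    visualizer.add_datasample(
        'demo',
        demo_image,
        data_sample=demo_sample,
        draw_gt=False,
        pred_score_thr=0.5,
        out_file='vis_demo.jpg')  # hypothetical output path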