Spaces:
Sleeping
Sleeping
File size: 6,362 Bytes
9bf4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union
import mmcv
import numpy as np
import torch
from mmocr.registry import VISUALIZERS
from mmocr.structures import TextDetDataSample
from mmocr.utils.polygon_utils import poly2bbox
from .base_visualizer import BaseLocalVisualizer
@VISUALIZERS.register_module()
class TextSpottingLocalVisualizer(BaseLocalVisualizer):
    """Visualizer that draws text-spotting instances next to their texts.

    Each drawn result is a side-by-side image: the input image with the
    instance regions filled in on the left, and a white canvas carrying the
    text labels and region outlines on the right.
    """

    @staticmethod
    def _is_nonempty(items) -> bool:
        """Return True when ``items`` is not None and has at least one entry.

        Using ``len`` instead of truthiness keeps the check valid for
        ``np.ndarray`` / ``torch.Tensor`` inputs, whose truth value is
        ambiguous when they hold more than one element.
        """
        return items is not None and len(items) > 0

    @staticmethod
    def _bboxes_or_from_polygons(bboxes, polygons):
        """Return ``bboxes``, deriving them from ``polygons`` when missing.

        Args:
            bboxes: The available bboxes, or None.
            polygons: The polygons to derive bboxes from, or None.

        Returns:
            ``bboxes`` unchanged when it is not None; otherwise an array
            built by applying ``poly2bbox`` to every polygon; None when both
            inputs are None.
        """
        if bboxes is None and polygons is not None:
            bboxes = np.array([poly2bbox(poly) for poly in polygons])
        return bboxes

    def _draw_instances(
        self,
        image: np.ndarray,
        bboxes: Optional[Union[np.ndarray, torch.Tensor]],
        polygons: Optional[Sequence[np.ndarray]],
        texts: Optional[Sequence[str]],
    ) -> np.ndarray:
        """Draw instances on image.

        Args:
            image (np.ndarray): The origin image to draw. The format
                should be RGB.
            bboxes (np.ndarray, torch.Tensor, optional): The bboxes to draw.
                The shape of bboxes should be (N, 4), where N is the number
                of texts. Only drawn when no polygons are given.
            polygons (Sequence[np.ndarray], optional): The polygons to draw.
                The length of polygons should be the same as the number of
                bboxes. Each polygon is reshaped to (-1, 2) before drawing.
            texts (Sequence[str], optional): The texts to draw. The length
                of texts should be the same as the number of bboxes.

        Returns:
            np.ndarray: ``image`` with the instances drawn, concatenated
            along the width axis with a white canvas of the same height
            and width that carries the texts and the instance outlines.
        """
        height, width = image.shape[:2]
        # White canvas on which the texts and the region outlines are drawn.
        text_image = np.full((height, width, 3), 255, dtype=np.uint8)

        if self._is_nonempty(texts):
            text_image = self.get_labels_image(
                text_image,
                labels=texts,
                bboxes=bboxes,
                font_families=self.font_families,
                font_properties=self.font_properties)

        # Polygons take precedence over bboxes; bboxes are only a fallback.
        if self._is_nonempty(polygons):
            polygons = [polygon.reshape(-1, 2) for polygon in polygons]
            image = self.get_polygons_image(
                image, polygons, filling=True, colors=self.PALETTE)
            text_image = self.get_polygons_image(
                text_image, polygons, colors=self.PALETTE)
        elif self._is_nonempty(bboxes):
            image = self.get_bboxes_image(
                image, bboxes, filling=True, colors=self.PALETTE)
            text_image = self.get_bboxes_image(
                text_image, bboxes, colors=self.PALETTE)

        return np.concatenate([image, text_image], axis=1)

    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: Optional['TextDetDataSample'] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: int = 0,
                       pred_score_thr: float = 0.5,
                       out_file: Optional[str] = None,
                       step: int = 0) -> np.ndarray:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          stacked vertically: the ground truth on top and the prediction
          below.
        - If ``show`` is True, all storage backends are ignored, and
          the images will be displayed in a local window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file``. This is usually used when the display
          is not available.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`TextSpottingDataSample`, optional):
                Data sample which contains gt and prediction. Defaults
                to None.
            draw_gt (bool): Whether to draw the ground truth instances.
                Defaults to True.
            draw_pred (bool): Whether to draw the predicted instances.
                Defaults to True.
            show (bool): Whether to display the drawn image. Defaults to
                False.
            wait_time (int): The interval of show (s). Defaults to 0.
            pred_score_thr (float): Only predictions whose score is
                strictly greater than this threshold are drawn. Defaults
                to 0.5.
            out_file (str, optional): Path to output file. Defaults to None.
            step (int): Global step value to record. Defaults to 0.

        Returns:
            np.ndarray: The value of ``self.get_image()`` after the drawn
            image has been set as the current image.
        """
        cat_images = []
        if data_sample is not None:
            if draw_gt and 'gt_instances' in data_sample:
                gt_instances = data_sample.gt_instances
                gt_polygons = gt_instances.get('polygons', None)
                # Fall back to polygon-derived boxes, as for predictions.
                gt_bboxes = self._bboxes_or_from_polygons(
                    gt_instances.get('bboxes', None), gt_polygons)
                gt_texts = gt_instances.texts
                gt_img_data = self._draw_instances(image, gt_bboxes,
                                                   gt_polygons, gt_texts)
                cat_images.append(gt_img_data)

            if draw_pred and 'pred_instances' in data_sample:
                pred_instances = data_sample.pred_instances
                # Keep only confident predictions, then move them to numpy.
                pred_instances = pred_instances[
                    pred_instances.scores > pred_score_thr].cpu().numpy()
                pred_polygons = pred_instances.get('polygons', None)
                pred_bboxes = self._bboxes_or_from_polygons(
                    pred_instances.get('bboxes', None), pred_polygons)
                pred_texts = pred_instances.texts
                pred_img_data = self._draw_instances(image, pred_bboxes,
                                                     pred_polygons, pred_texts)
                cat_images.append(pred_img_data)

        cat_images = self._cat_image(cat_images, axis=0)
        # Nothing was drawn: fall back to the untouched input image.
        if cat_images is None:
            cat_images = image

        if show:
            self.show(cat_images, win_name=name, wait_time=wait_time)
        else:
            self.add_image(name, cat_images, step)

        if out_file is not None:
            # Reverse the channel axis before writing; presumably RGB -> BGR
            # because mmcv.imwrite expects BGR -- confirm against mmcv docs.
            mmcv.imwrite(cat_images[..., ::-1], out_file)

        self.set_image(cat_images)
        return self.get_image()
|