# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

import mmcv
import numpy as np
import torch

from mmocr.registry import VISUALIZERS
from mmocr.structures import TextDetDataSample
from mmocr.utils.polygon_utils import poly2bbox

from .base_visualizer import BaseLocalVisualizer
@VISUALIZERS.register_module()
class TextSpottingLocalVisualizer(BaseLocalVisualizer):
    """MMOCR Text Spotting Local Visualizer.

    Draws detection polygons (or bboxes) on the input image and renders the
    corresponding recognized texts on a white canvas stitched to the right of
    the image. GT and prediction panels, when both drawn, are stacked
    vertically (GT on top).
    """

    def _draw_instances(
        self,
        image: np.ndarray,
        bboxes: Union[np.ndarray, torch.Tensor],
        polygons: Sequence[np.ndarray],
        texts: Sequence[str],
    ) -> np.ndarray:
        """Draw instances on image.

        Args:
            image (np.ndarray): The origin image to draw. The format
                should be RGB.
            bboxes (np.ndarray, torch.Tensor): The bboxes to draw. The shape
                of bboxes should be (N, 4), where N is the number of texts.
            polygons (Sequence[np.ndarray]): The polygons to draw. The length
                of polygons should be the same as the number of bboxes.
            texts (Sequence[str]): The texts to draw. The length of texts
                should be the same as the number of bboxes.

        Returns:
            np.ndarray: The image with instances drawn, horizontally
            concatenated with a white canvas carrying the rendered texts
            (shape (H, 2 * W, 3)).
        """
        img_shape = image.shape[:2]
        empty_shape = (img_shape[0], img_shape[1], 3)
        # White canvas on which the recognized texts are rendered, one per
        # instance, positioned by the instance's bbox.
        text_image = np.full(empty_shape, 255, dtype=np.uint8)
        if texts:
            text_image = self.get_labels_image(
                text_image,
                labels=texts,
                bboxes=bboxes,
                font_families=self.font_families,
                font_properties=self.font_properties)
        if polygons:
            # Flatten any (2k,) polygon to point form (k, 2) before drawing.
            polygons = [polygon.reshape(-1, 2) for polygon in polygons]
            image = self.get_polygons_image(
                image, polygons, filling=True, colors=self.PALETTE)
            text_image = self.get_polygons_image(
                text_image, polygons, colors=self.PALETTE)
        elif len(bboxes) > 0:
            # No polygons available: fall back to axis-aligned boxes.
            image = self.get_bboxes_image(
                image, bboxes, filling=True, colors=self.PALETTE)
            text_image = self.get_bboxes_image(
                text_image, bboxes, colors=self.PALETTE)
        return np.concatenate([image, text_image], axis=1)

    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: Optional['TextDetDataSample'] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: int = 0,
                       pred_score_thr: float = 0.5,
                       out_file: Optional[str] = None,
                       step: int = 0) -> np.ndarray:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          displayed in a stitched image where the top image is the
          ground truth and the bottom image is the prediction.
        - If ``show`` is True, all storage backends are ignored, and
          the images will be displayed in a local window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file``. This is usually used when the display
          is not available.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`TextSpottingDataSample`, optional):
                TextDetDataSample which contains gt and prediction.
                Defaults to None.
            draw_gt (bool): Whether to draw GT TextDetDataSample.
                Defaults to True.
            draw_pred (bool): Whether to draw Predicted TextDetDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image. Default to
                False.
            wait_time (float): The interval of show (s). Defaults to 0.
            pred_score_thr (float): The threshold to visualize the bboxes
                and masks. Defaults to 0.5.
            out_file (str): Path to output file. Defaults to None.
            step (int): Global step value to record. Defaults to 0.

        Returns:
            np.ndarray: The drawn (possibly stitched) image.
        """
        cat_images = []

        if data_sample is not None:
            if draw_gt and 'gt_instances' in data_sample:
                gt_bboxes = data_sample.gt_instances.get('bboxes', None)
                gt_texts = data_sample.gt_instances.texts
                gt_polygons = data_sample.gt_instances.get('polygons', None)
                if gt_bboxes is None:
                    # Mirror the prediction branch: derive bboxes from the
                    # polygons when the sample carries polygons only.
                    gt_bboxes = np.array(
                        [poly2bbox(poly) for poly in gt_polygons])
                gt_img_data = self._draw_instances(image, gt_bboxes,
                                                   gt_polygons, gt_texts)
                cat_images.append(gt_img_data)

            if draw_pred and 'pred_instances' in data_sample:
                pred_instances = data_sample.pred_instances
                # Keep only confident predictions and move them to CPU numpy
                # so drawing never touches GPU tensors.
                pred_instances = pred_instances[
                    pred_instances.scores > pred_score_thr].cpu().numpy()
                pred_bboxes = pred_instances.get('bboxes', None)
                pred_texts = pred_instances.texts
                pred_polygons = pred_instances.get('polygons', None)
                if pred_bboxes is None:
                    pred_bboxes = [poly2bbox(poly) for poly in pred_polygons]
                    pred_bboxes = np.array(pred_bboxes)
                pred_img_data = self._draw_instances(image, pred_bboxes,
                                                     pred_polygons,
                                                     pred_texts)
                cat_images.append(pred_img_data)

        # Stack GT above prediction; fall back to the raw image when
        # nothing was drawn.
        cat_images = self._cat_image(cat_images, axis=0)
        if cat_images is None:
            cat_images = image

        if show:
            self.show(cat_images, win_name=name, wait_time=wait_time)
        else:
            self.add_image(name, cat_images, step)

        if out_file is not None:
            # Backends expect BGR on disk; the in-memory image is RGB.
            mmcv.imwrite(cat_images[..., ::-1], out_file)

        self.set_image(cat_images)
        return self.get_image()