# Source: MAERec-Gradio / mmocr/visualization/textspotting_visualizer.py
# (uploaded by Mountchicken in "Upload 704 files", commit 9bf4bd7; 6.36 kB)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union
import mmcv
import numpy as np
import torch
from mmocr.registry import VISUALIZERS
from mmocr.structures import TextDetDataSample
from mmocr.utils.polygon_utils import poly2bbox
from .base_visualizer import BaseLocalVisualizer
@VISUALIZERS.register_module()
class TextSpottingLocalVisualizer(BaseLocalVisualizer):
    """Local visualizer for text spotting results.

    Each drawn set of instances becomes a panel: the input image with the
    instance regions filled in, concatenated on the right with a white
    canvas that shows the instance outlines and their texts.
    """

    @staticmethod
    def _get_instance_bboxes(
            instances) -> Optional[Union[np.ndarray, torch.Tensor]]:
        """Return the bboxes of ``instances``, deriving them if needed.

        Instances may carry only polygons. The texts are drawn using bboxes,
        so in that case the bboxes are rebuilt from the polygons. GT and
        predicted instances are treated the same way.

        Args:
            instances: Instance container supporting ``get(key, default)``,
                such as ``gt_instances`` or ``pred_instances``.

        Returns:
            np.ndarray | torch.Tensor | None: The stored ``bboxes`` when
            present; otherwise an (N, 4) array obtained by applying
            :func:`poly2bbox` to each polygon; None when the instances have
            neither bboxes nor polygons.
        """
        bboxes = instances.get('bboxes', None)
        if bboxes is not None:
            return bboxes
        polygons = instances.get('polygons', None)
        if polygons is None:
            return None
        derived = [poly2bbox(poly) for poly in polygons]
        # reshape keeps the result 2-D: (0, 4) rather than (0,) when there
        # are no polygons (e.g. every prediction was filtered out).
        return np.array(derived).reshape(-1, 4)

    def _draw_instances(
        self,
        image: np.ndarray,
        bboxes: Optional[Union[np.ndarray, torch.Tensor]],
        polygons: Optional[Sequence[np.ndarray]],
        texts: Sequence[str],
    ) -> np.ndarray:
        """Draw one set of text instances.

        Args:
            image (np.ndarray): The original image to draw on. The format
                should be RGB.
            bboxes (np.ndarray, torch.Tensor, optional): The bboxes, of
                shape (N, 4) where N is the number of texts. They are passed
                to ``get_labels_image`` to draw ``texts``. They are also used
                as the instance outlines when no polygons are given. If None
                or empty, neither texts nor bbox outlines are drawn.
            polygons (Sequence[np.ndarray], optional): One polygon per
                instance. Each is reshaped to (K, 2) points before drawing.
                When non-empty, polygons take precedence over ``bboxes`` for
                the outlines.
            texts (Sequence[str]): The texts to draw, one per bbox.

        Returns:
            np.ndarray: ``image`` with the instance regions filled in,
            concatenated along axis 1 with a white canvas of the same height
            and width showing the outlines and texts.
        """
        height, width = image.shape[:2]
        text_image = np.full((height, width, 3), 255, dtype=np.uint8)
        # len() instead of truthiness so array-like inputs work as well.
        has_bboxes = bboxes is not None and len(bboxes) > 0
        if texts is not None and len(texts) > 0 and has_bboxes:
            text_image = self.get_labels_image(
                text_image,
                labels=texts,
                bboxes=bboxes,
                font_families=self.font_families,
                font_properties=self.font_properties)
        if polygons is not None and len(polygons) > 0:
            polygons = [polygon.reshape(-1, 2) for polygon in polygons]
            image = self.get_polygons_image(
                image, polygons, filling=True, colors=self.PALETTE)
            text_image = self.get_polygons_image(
                text_image, polygons, colors=self.PALETTE)
        elif has_bboxes:
            image = self.get_bboxes_image(
                image, bboxes, filling=True, colors=self.PALETTE)
            text_image = self.get_bboxes_image(
                text_image, bboxes, colors=self.PALETTE)
        return np.concatenate([image, text_image], axis=1)

    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: Optional['TextDetDataSample'] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: float = 0,
                       pred_score_thr: float = 0.5,
                       out_file: Optional[str] = None,
                       step: int = 0) -> np.ndarray:
        """Draw a data sample, then show it or save it to all backends.

        - GT instances (if ``draw_gt``) and predicted instances (if
          ``draw_pred``) are each drawn as one panel by
          :meth:`_draw_instances`. The panels are combined with
          ``self._cat_image(..., axis=0)``, GT panel first, i.e. above the
          prediction panel.
        - If ``_cat_image`` yields None (nothing was drawn), the input image
          is used as is.
        - If ``show`` is True, the image is displayed in a local window and
          the storage backends are skipped. Otherwise it is recorded with
          ``add_image``.
        - If ``out_file`` is specified, the image is also written there,
          with the channel order reversed (RGB to BGR) for ``mmcv.imwrite``.
          This happens whether or not ``show`` is set.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw, in RGB order.
            data_sample (TextDetDataSample, optional): Sample holding
                ``gt_instances`` and/or ``pred_instances``. Defaults to None.
            draw_gt (bool): Whether to draw GT instances. Defaults to True.
            draw_pred (bool): Whether to draw predicted instances.
                Defaults to True.
            show (bool): Whether to display the drawn image. Defaults to
                False.
            wait_time (float): The interval of show (s). Defaults to 0.
            pred_score_thr (float): Only predictions whose score is strictly
                greater than this threshold are drawn. Defaults to 0.5.
            out_file (str, optional): Path to output file. Defaults to None.
            step (int): Global step value to record. Defaults to 0.

        Returns:
            np.ndarray: The drawn image, as returned by ``self.get_image()``.
        """
        cat_images = []
        if data_sample is not None:
            if draw_gt and 'gt_instances' in data_sample:
                gt_instances = data_sample.gt_instances
                cat_images.append(
                    self._draw_instances(
                        image, self._get_instance_bboxes(gt_instances),
                        gt_instances.get('polygons', None),
                        gt_instances.texts))
            if draw_pred and 'pred_instances' in data_sample:
                pred_instances = data_sample.pred_instances
                # Keep only confident predictions, then move to CPU/numpy.
                pred_instances = pred_instances[
                    pred_instances.scores > pred_score_thr].cpu().numpy()
                cat_images.append(
                    self._draw_instances(
                        image, self._get_instance_bboxes(pred_instances),
                        pred_instances.get('polygons', None),
                        pred_instances.texts))
        cat_images = self._cat_image(cat_images, axis=0)
        if cat_images is None:
            cat_images = image
        if show:
            self.show(cat_images, win_name=name, wait_time=wait_time)
        else:
            self.add_image(name, cat_images, step)
        if out_file is not None:
            mmcv.imwrite(cat_images[..., ::-1], out_file)
        self.set_image(cat_images)
        return self.get_image()