Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import Dict, List, Optional, Sequence, Union | |
import mmcv | |
import numpy as np | |
import torch | |
from matplotlib.collections import PatchCollection | |
from matplotlib.patches import FancyArrow | |
from mmengine.visualization import Visualizer | |
from mmengine.visualization.utils import (check_type, check_type_and_length, | |
color_val_matplotlib, tensor2ndarray, | |
value2list) | |
from mmocr.registry import VISUALIZERS | |
from mmocr.structures import KIEDataSample | |
from .base_visualizer import BaseLocalVisualizer | |
class KIELocalVisualizer(BaseLocalVisualizer): | |
"""The MMOCR Text Detection Local Visualizer. | |
Args: | |
name (str): Name of the instance. Defaults to 'visualizer'. | |
image (np.ndarray, optional): the origin image to draw. The format | |
should be RGB. Defaults to None. | |
vis_backends (list, optional): Visual backend config list. | |
Default to None. | |
save_dir (str, optional): Save file dir for all storage backends. | |
If it is None, the backend storage will not save any data. | |
fig_save_cfg (dict): Keyword parameters of figure for saving. | |
Defaults to empty dict. | |
fig_show_cfg (dict): Keyword parameters of figure for showing. | |
Defaults to empty dict. | |
is_openset (bool, optional): Whether the visualizer is used in | |
OpenSet. Defaults to False. | |
""" | |
def __init__(self, | |
name: str = 'kie_visualizer', | |
is_openset: bool = False, | |
**kwargs) -> None: | |
super().__init__(name=name, **kwargs) | |
self.is_openset = is_openset | |
def _draw_edge_label(self, | |
image: np.ndarray, | |
edge_labels: Union[np.ndarray, torch.Tensor], | |
bboxes: Union[np.ndarray, torch.Tensor], | |
texts: Sequence[str], | |
arrow_colors: str = 'g') -> np.ndarray: | |
"""Draw edge labels on image. | |
Args: | |
image (np.ndarray): The origin image to draw. The format | |
should be RGB. | |
edge_labels (np.ndarray or torch.Tensor): The edge labels to draw. | |
The shape of edge_labels should be (N, N), where N is the | |
number of texts. | |
bboxes (np.ndarray or torch.Tensor): The bboxes to draw. The shape | |
of bboxes should be (N, 4), where N is the number of texts. | |
texts (Sequence[str]): The texts to draw. The length of texts | |
should be the same as the number of bboxes. | |
arrow_colors (str, optional): The colors of arrows. Refer to | |
`matplotlib.colors` for full list of formats that are accepted. | |
Defaults to 'g'. | |
Returns: | |
np.ndarray: The image with edge labels drawn. | |
""" | |
pairs = np.where(edge_labels > 0) | |
if torch.is_tensor(pairs): | |
pairs = pairs.cpu() | |
key_bboxes = bboxes[pairs[0]] | |
value_bboxes = bboxes[pairs[1]] | |
x_data = np.stack([(key_bboxes[:, 2] + key_bboxes[:, 0]) / 2, | |
(value_bboxes[:, 0] + value_bboxes[:, 2]) / 2], | |
axis=-1) | |
y_data = np.stack([(key_bboxes[:, 1] + key_bboxes[:, 3]) / 2, | |
(value_bboxes[:, 1] + value_bboxes[:, 3]) / 2], | |
axis=-1) | |
key_index = np.array(list(set(pairs[0]))) | |
val_index = np.array(list(set(pairs[1]))) | |
key_texts = [texts[i] for i in key_index] | |
val_texts = [texts[i] for i in val_index] | |
self.set_image(image) | |
if key_texts: | |
self.draw_texts( | |
key_texts, (bboxes[key_index, :2] + bboxes[key_index, 2:]) / 2, | |
colors='k', | |
horizontal_alignments='center', | |
vertical_alignments='center', | |
font_families=self.font_families, | |
font_properties=self.font_properties) | |
if val_texts: | |
self.draw_texts( | |
val_texts, (bboxes[val_index, :2] + bboxes[val_index, 2:]) / 2, | |
colors='k', | |
horizontal_alignments='center', | |
vertical_alignments='center', | |
font_families=self.font_families, | |
font_properties=self.font_properties) | |
self.draw_arrows( | |
x_data, | |
y_data, | |
colors=arrow_colors, | |
line_widths=0.3, | |
arrow_tail_widths=0.05, | |
arrow_head_widths=5, | |
overhangs=1, | |
arrow_shapes='full') | |
return self.get_image() | |
def _draw_instances( | |
self, | |
image: np.ndarray, | |
bbox_labels: Union[np.ndarray, torch.Tensor], | |
bboxes: Union[np.ndarray, torch.Tensor], | |
polygons: Sequence[np.ndarray], | |
edge_labels: Union[np.ndarray, torch.Tensor], | |
texts: Sequence[str], | |
class_names: Dict, | |
is_openset: bool = False, | |
arrow_colors: str = 'g', | |
) -> np.ndarray: | |
"""Draw instances on image. | |
Args: | |
image (np.ndarray): The origin image to draw. The format | |
should be RGB. | |
bbox_labels (np.ndarray or torch.Tensor): The bbox labels to draw. | |
The shape of bbox_labels should be (N,), where N is the | |
number of texts. | |
bboxes (np.ndarray or torch.Tensor): The bboxes to draw. The shape | |
of bboxes should be (N, 4), where N is the number of texts. | |
polygons (Sequence[np.ndarray]): The polygons to draw. The length | |
of polygons should be the same as the number of bboxes. | |
edge_labels (np.ndarray or torch.Tensor): The edge labels to draw. | |
The shape of edge_labels should be (N, N), where N is the | |
number of texts. | |
texts (Sequence[str]): The texts to draw. The length of texts | |
should be the same as the number of bboxes. | |
class_names (dict): The class names for bbox labels. | |
is_openset (bool): Whether the dataset is openset. Defaults to | |
False. | |
arrow_colors (str, optional): The colors of arrows. Refer to | |
`matplotlib.colors` for full list of formats that are accepted. | |
Defaults to 'g'. | |
Returns: | |
np.ndarray: The image with instances drawn. | |
""" | |
img_shape = image.shape[:2] | |
empty_shape = (img_shape[0], img_shape[1], 3) | |
text_image = np.full(empty_shape, 255, dtype=np.uint8) | |
text_image = self.get_labels_image( | |
text_image, | |
texts, | |
bboxes, | |
font_families=self.font_families, | |
font_properties=self.font_properties) | |
classes_image = np.full(empty_shape, 255, dtype=np.uint8) | |
bbox_classes = [class_names[int(i)]['name'] for i in bbox_labels] | |
classes_image = self.get_labels_image( | |
classes_image, | |
bbox_classes, | |
bboxes, | |
font_families=self.font_families, | |
font_properties=self.font_properties) | |
if polygons: | |
polygons = [polygon.reshape(-1, 2) for polygon in polygons] | |
image = self.get_polygons_image( | |
image, polygons, filling=True, colors=self.PALETTE) | |
text_image = self.get_polygons_image( | |
text_image, polygons, colors=self.PALETTE) | |
classes_image = self.get_polygons_image( | |
classes_image, polygons, colors=self.PALETTE) | |
else: | |
image = self.get_bboxes_image( | |
image, bboxes, filling=True, colors=self.PALETTE) | |
text_image = self.get_bboxes_image( | |
text_image, bboxes, colors=self.PALETTE) | |
classes_image = self.get_bboxes_image( | |
classes_image, bboxes, colors=self.PALETTE) | |
cat_image = [image, text_image, classes_image] | |
if is_openset: | |
edge_image = np.full(empty_shape, 255, dtype=np.uint8) | |
edge_image = self._draw_edge_label(edge_image, edge_labels, bboxes, | |
texts, arrow_colors) | |
cat_image.append(edge_image) | |
return self._cat_image(cat_image, axis=1) | |
def add_datasample(self, | |
name: str, | |
image: np.ndarray, | |
data_sample: Optional['KIEDataSample'] = None, | |
draw_gt: bool = True, | |
draw_pred: bool = True, | |
show: bool = False, | |
wait_time: int = 0, | |
pred_score_thr: float = None, | |
out_file: Optional[str] = None, | |
step: int = 0) -> None: | |
"""Draw datasample and save to all backends. | |
- If GT and prediction are plotted at the same time, they are | |
displayed in a stitched image where the left image is the | |
ground truth and the right image is the prediction. | |
- If ``show`` is True, all storage backends are ignored, and | |
the images will be displayed in a local window. | |
- If ``out_file`` is specified, the drawn image will be | |
saved to ``out_file``. This is usually used when the display | |
is not available. | |
Args: | |
name (str): The image identifier. | |
image (np.ndarray): The image to draw. | |
data_sample (:obj:`KIEDataSample`, optional): | |
KIEDataSample which contains gt and prediction. Defaults | |
to None. | |
draw_gt (bool): Whether to draw GT KIEDataSample. | |
Defaults to True. | |
draw_pred (bool): Whether to draw Predicted KIEDataSample. | |
Defaults to True. | |
show (bool): Whether to display the drawn image. Default to False. | |
wait_time (float): The interval of show (s). Defaults to 0. | |
pred_score_thr (float): The threshold to visualize the bboxes | |
and masks. Defaults to 0.3. | |
out_file (str): Path to output file. Defaults to None. | |
step (int): Global step value to record. Defaults to 0. | |
""" | |
cat_images = list() | |
if draw_gt: | |
gt_bboxes = data_sample.gt_instances.bboxes | |
gt_labels = data_sample.gt_instances.labels | |
gt_texts = data_sample.gt_instances.texts | |
gt_polygons = data_sample.gt_instances.get('polygons', None) | |
gt_edge_labels = data_sample.gt_instances.get('edge_labels', None) | |
gt_img_data = self._draw_instances(image, gt_labels, gt_bboxes, | |
gt_polygons, gt_edge_labels, | |
gt_texts, | |
self.dataset_meta['category'], | |
self.is_openset, 'g') | |
cat_images.append(gt_img_data) | |
if draw_pred: | |
gt_bboxes = data_sample.gt_instances.bboxes | |
pred_labels = data_sample.pred_instances.labels | |
gt_texts = data_sample.gt_instances.texts | |
gt_polygons = data_sample.gt_instances.get('polygons', None) | |
pred_edge_labels = data_sample.pred_instances.get( | |
'edge_labels', None) | |
pred_img_data = self._draw_instances(image, pred_labels, gt_bboxes, | |
gt_polygons, pred_edge_labels, | |
gt_texts, | |
self.dataset_meta['category'], | |
self.is_openset, 'r') | |
cat_images.append(pred_img_data) | |
cat_images = self._cat_image(cat_images, axis=0) | |
if cat_images is None: | |
cat_images = image | |
if show: | |
self.show(cat_images, win_name=name, wait_time=wait_time) | |
else: | |
self.add_image(name, cat_images, step) | |
if out_file is not None: | |
mmcv.imwrite(cat_images[..., ::-1], out_file) | |
self.set_image(cat_images) | |
return self.get_image() | |
def draw_arrows(self, | |
x_data: Union[np.ndarray, torch.Tensor], | |
y_data: Union[np.ndarray, torch.Tensor], | |
colors: Union[str, tuple, List[str], List[tuple]] = 'C1', | |
line_widths: Union[Union[int, float], | |
List[Union[int, float]]] = 1, | |
line_styles: Union[str, List[str]] = '-', | |
arrow_tail_widths: Union[Union[int, float], | |
List[Union[int, float]]] = 0.001, | |
arrow_head_widths: Union[Union[int, float], | |
List[Union[int, float]]] = None, | |
arrow_head_lengths: Union[Union[int, float], | |
List[Union[int, float]]] = None, | |
arrow_shapes: Union[str, List[str]] = 'full', | |
overhangs: Union[int, List[int]] = 0) -> 'Visualizer': | |
"""Draw single or multiple arrows. | |
Args: | |
x_data (np.ndarray or torch.Tensor): The x coordinate of | |
each line' start and end points. | |
y_data (np.ndarray, torch.Tensor): The y coordinate of | |
each line' start and end points. | |
colors (str or tuple or list[str or tuple]): The colors of | |
lines. ``colors`` can have the same length with lines or just | |
single value. If ``colors`` is single value, all the lines | |
will have the same colors. Reference to | |
https://matplotlib.org/stable/gallery/color/named_colors.html | |
for more details. Defaults to 'g'. | |
line_widths (int or float or list[int or float]): | |
The linewidth of lines. ``line_widths`` can have | |
the same length with lines or just single value. | |
If ``line_widths`` is single value, all the lines will | |
have the same linewidth. Defaults to 2. | |
line_styles (str or list[str]]): The linestyle of lines. | |
``line_styles`` can have the same length with lines or just | |
single value. If ``line_styles`` is single value, all the | |
lines will have the same linestyle. Defaults to '-'. | |
arrow_tail_widths (int or float or list[int, float]): | |
The width of arrow tails. ``arrow_tail_widths`` can have | |
the same length with lines or just single value. If | |
``arrow_tail_widths`` is single value, all the lines will | |
have the same width. Defaults to 0.001. | |
arrow_head_widths (int or float or list[int, float]): | |
The width of arrow heads. ``arrow_head_widths`` can have | |
the same length with lines or just single value. If | |
``arrow_head_widths`` is single value, all the lines will | |
have the same width. Defaults to None. | |
arrow_head_lengths (int or float or list[int, float]): | |
The length of arrow heads. ``arrow_head_lengths`` can have | |
the same length with lines or just single value. If | |
``arrow_head_lengths`` is single value, all the lines will | |
have the same length. Defaults to None. | |
arrow_shapes (str or list[str]]): The shapes of arrow heads. | |
``arrow_shapes`` can have the same length with lines or just | |
single value. If ``arrow_shapes`` is single value, all the | |
lines will have the same shape. Defaults to 'full'. | |
overhangs (int or list[int]]): The overhangs of arrow heads. | |
``overhangs`` can have the same length with lines or just | |
single value. If ``overhangs`` is single value, all the lines | |
will have the same overhangs. Defaults to 0. | |
""" | |
check_type('x_data', x_data, (np.ndarray, torch.Tensor)) | |
x_data = tensor2ndarray(x_data) | |
check_type('y_data', y_data, (np.ndarray, torch.Tensor)) | |
y_data = tensor2ndarray(y_data) | |
assert x_data.shape == y_data.shape, ( | |
'`x_data` and `y_data` should have the same shape') | |
assert x_data.shape[-1] == 2, ( | |
f'The shape of `x_data` should be (N, 2), but got {x_data.shape}') | |
if len(x_data.shape) == 1: | |
x_data = x_data[None] | |
y_data = y_data[None] | |
number_arrow = x_data.shape[0] | |
check_type_and_length('colors', colors, (str, tuple, list), | |
number_arrow) | |
colors = value2list(colors, (str, tuple), number_arrow) | |
colors = color_val_matplotlib(colors) # type: ignore | |
check_type_and_length('line_widths', line_widths, (int, float), | |
number_arrow) | |
line_widths = value2list(line_widths, (int, float), number_arrow) | |
check_type_and_length('arrow_tail_widths', arrow_tail_widths, | |
(int, float), number_arrow) | |
check_type_and_length('line_styles', line_styles, str, number_arrow) | |
line_styles = value2list(line_styles, str, number_arrow) | |
arrow_tail_widths = value2list(arrow_tail_widths, (int, float), | |
number_arrow) | |
check_type_and_length('arrow_head_widths', arrow_head_widths, | |
(int, float, type(None)), number_arrow) | |
arrow_head_widths = value2list(arrow_head_widths, | |
(int, float, type(None)), number_arrow) | |
check_type_and_length('arrow_head_lengths', arrow_head_lengths, | |
(int, float, type(None)), number_arrow) | |
arrow_head_lengths = value2list(arrow_head_lengths, | |
(int, float, type(None)), number_arrow) | |
check_type_and_length('arrow_shapes', arrow_shapes, (str, list), | |
number_arrow) | |
arrow_shapes = value2list(arrow_shapes, (str, list), number_arrow) | |
check_type('overhang', overhangs, int) | |
overhangs = value2list(overhangs, int, number_arrow) | |
lines = np.concatenate( | |
(x_data.reshape(-1, 2, 1), y_data.reshape(-1, 2, 1)), axis=-1) | |
if not self._is_posion_valid(lines): | |
warnings.warn( | |
'Warning: The line is out of bounds,' | |
' the drawn line may not be in the image', UserWarning) | |
arrows = [] | |
for i in range(number_arrow): | |
arrows.append( | |
FancyArrow( | |
*tuple(lines[i, 0]), | |
*tuple(lines[i, 1] - lines[i, 0]), | |
linestyle=line_styles[i], | |
color=colors[i], | |
length_includes_head=True, | |
width=arrow_tail_widths[i], | |
head_width=arrow_head_widths[i], | |
head_length=arrow_head_lengths[i], | |
overhang=overhangs[i], | |
shape=arrow_shapes[i], | |
linewidth=line_widths[i])) | |
p = PatchCollection(arrows, match_original=True) | |
self.ax_save.add_collection(p) | |
return self | |