# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine.structures import InstanceData

from mmocr.models import Dictionary
from mmocr.models.textrecog.postprocessors import BaseTextRecogPostprocessor
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample
from mmocr.utils import rescale_polygons


@MODELS.register_module()
class SPTSPostprocessor(BaseTextRecogPostprocessor):
    """PostProcessor for SPTS.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        num_bins (int): Number of bins dividing the image. Predicted point
            tokens are divided by this value and multiplied by the image
            width/height to recover coordinates.
        rescale_fields (list[str], optional): The bbox/polygon field names to
            be rescaled. If None, no rescaling will be performed.
            Defaults to ['points'].
        max_seq_len (int): Maximum sequence length. In SPTS, a sequence
            encodes all the text instances in a sample. Defaults to 40, which
            will be overridden by SPTSDecoder.
        ignore_chars (list[str]): A list of characters to be ignored from the
            final results. Postprocessor will skip over these characters when
            converting raw indexes to characters. Apart from single
            characters, each item can be one of the following reserved
            keywords: 'padding', 'end' and 'unknown', which refer to their
            corresponding special tokens in the dictionary.
            Defaults to ['padding'].
    """

    def __init__(self,
                 dictionary: Union[Dictionary, Dict],
                 num_bins: int,
                 rescale_fields: Optional[Sequence[str]] = ['points'],
                 max_seq_len: int = 40,
                 ignore_chars: Sequence[str] = ['padding'],
                 **kwargs) -> None:
        assert rescale_fields is None or isinstance(rescale_fields, list)
        self.num_bins = num_bins
        self.rescale_fields = rescale_fields
        super().__init__(
            dictionary=dictionary,
            num_bins=num_bins,
            max_seq_len=max_seq_len,
            ignore_chars=ignore_chars)

    def get_single_prediction(
        self,
        max_probs: torch.Tensor,
        seq: torch.Tensor,
        data_sample: Optional[TextSpottingDataSample] = None,
    ) -> Tuple[List[List[int]], List[float], List[Tuple[float, float]],
               List[float]]:
        """Convert the output probabilities of a single image to character
        indexes, text scores, points and point scores.

        Args:
            max_probs (torch.Tensor): Character probabilities with shape
                :math:`(T)`.
            seq (torch.Tensor): Sequence indexes with shape :math:`(T)`.
            data_sample (TextSpottingDataSample, optional): Datasample of an
                image. Although annotated as optional, its ``img_shape`` is
                read unconditionally, so it must be provided.

        Returns:
            tuple(list[list[int]], list[float], list[(float, float)],
            list[float]): character indexes, text scores, points and point
            scores. Each list holds one entry per 27-token group found in
            the input sequence.
        """
        h, w = data_sample.img_shape
        # Each instance occupies 27 consecutive tokens: the first two are read
        # as the (x, y) point and the remaining 25 as characters. No trimming
        # is done before the reshape; per the original author's note this is
        # not needed because the softmaxed outputs are masked out in the
        # decoder.
        tokens_per_instance = 27
        max_probs = max_probs.reshape(-1, tokens_per_instance)
        seq = seq.reshape(-1, tokens_per_instance)

        indexes, text_scores, points, pt_scores = [], [], [], []
        output_indexes = seq.cpu().detach().numpy().tolist()
        output_scores = max_probs.cpu().detach().numpy().tolist()
        for output_index, output_score in zip(output_indexes, output_scores):
            # Map the quantized bin indexes back to pixel coordinates.
            point_x = output_index[0] / self.num_bins * w
            point_y = output_index[1] / self.num_bins * h
            points.append((point_x, point_y))
            # The point score is the mean of the x and y token probabilities.
            pt_scores.append(
                np.mean([
                    output_score[0],
                    output_score[1],
                ]).item())

            char_indexes = []
            char_scores = []
            for char_index, char_score in zip(output_index[2:],
                                              output_score[2:]):
                # the first num_bins indexes are for points
                if char_index in self.ignore_indexes:
                    continue
                if char_index == self.dictionary.end_idx:
                    break
                char_indexes.append(char_index)
                char_scores.append(char_score)
            indexes.append(char_indexes)
            # NOTE: when no character survives the filtering above,
            # np.mean of the empty list evaluates to nan (numpy also emits a
            # RuntimeWarning). This is kept as-is so existing scores do not
            # change silently.
            text_scores.append(np.mean(char_scores).item())
        return indexes, text_scores, points, pt_scores

    def __call__(
        self, output: Tuple[torch.Tensor, torch.Tensor],
        data_samples: Sequence[TextSpottingDataSample]
    ) -> Sequence[TextSpottingDataSample]:
        """Convert outputs to strings and scores.

        Args:
            output (tuple(Tensor, Tensor)): A tuple of (probs, seq), each has
                the shape of :math:`(N, T)`, where the first dimension
                indexes the samples in the batch.
            data_samples (list[TextSpottingDataSample]): The list of
                TextSpottingDataSample.

        Returns:
            list(TextSpottingDataSample): The same list of
            TextSpottingDataSample, each with ``pred_instances`` set to the
            decoded ``texts``, ``scores``, ``text_scores`` and ``points``.
        """
        max_probs, seq = output
        batch_size = max_probs.size(0)

        for idx in range(batch_size):
            (char_idxs, text_scores, points,
             pt_scores) = self.get_single_prediction(max_probs[idx, :],
                                                     seq[idx, :],
                                                     data_samples[idx])
            texts = []
            scores = []
            for index, pt_score in zip(char_idxs, pt_scores):
                text = self.dictionary.idx2str(index)
                texts.append(text)
                # the "scores" field only accepts a float number
                scores.append(np.mean(pt_score).item())
            pred_instances = InstanceData()
            pred_instances.texts = texts
            pred_instances.scores = scores
            pred_instances.text_scores = text_scores
            pred_instances.points = points
            data_samples[idx].pred_instances = pred_instances
            # rescale() updates the sample's pred_instances in place, so its
            # return value does not need to be kept.
            self.rescale(data_samples[idx], data_samples[idx].scale_factor)
        return data_samples

    def rescale(self, results: TextSpottingDataSample,
                scale_factor: Sequence[int]) -> TextSpottingDataSample:
        """Rescale results in ``results.pred_instances`` according to
        ``scale_factor``, whose keys are defined in ``self.rescale_fields``.
        Usually used to rescale bboxes and/or polygons.

        Args:
            results (TextSpottingDataSample): The post-processed prediction
                results.
            scale_factor (tuple(int)): (w_scale, h_scale)

        Returns:
            TextSpottingDataSample: Prediction results with rescaled results.
            Returned unchanged when ``self.rescale_fields`` is None or empty.
        """
        # ``rescale_fields=None`` is documented as "no rescaling"; iterating
        # over None would raise TypeError, so bail out early.
        if not self.rescale_fields:
            return results
        scale_factor = np.asarray(scale_factor)
        for key in self.rescale_fields:
            # TODO: this util may need an alias or to be renamed
            results.pred_instances[key] = rescale_polygons(
                results.pred_instances[key], scale_factor, mode='div')
        return results