import torch

from .nrtr_postprocess import NRTRLabelDecode


class CPPDLabelDecode(NRTRLabelDecode):
    """Convert between text-label and text-index for CPPD recognition outputs.

    Reuses the character table and ``decode`` logic inherited from
    ``NRTRLabelDecode``. It only changes how raw model outputs are unwrapped
    before decoding and which special token occupies class index 0.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        # Extra keyword arguments are accepted for config compatibility but
        # are deliberately not forwarded to the parent constructor.
        super(CPPDLabelDecode, self).__init__(character_dict_path, use_space_char)

    @staticmethod
    def _to_numpy(preds):
        """Unwrap a raw model output into the final prediction array.

        Accepted forms:
          * tuple / list  -- its last element is unwrapped further;
          * dict          -- ``preds['align'][-1]`` is used;
          * torch.Tensor  -- detached, moved to CPU and converted to NumPy;
          * anything else (e.g. an ndarray) is returned unchanged.
        """
        if isinstance(preds, (tuple, list)):
            # Only the last entry of a multi-output return is decoded.
            preds = preds[-1]
        if isinstance(preds, dict):
            # Use the last entry of the 'align' outputs.
            preds = preds['align'][-1]
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()
        return preds

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode predictions (and optionally ground-truth labels) into text.

        Args:
            preds: Model output in any form accepted by ``_to_numpy``. After
                unwrapping, axis 2 is treated as the class axis (presumably a
                ``(batch, seq_len, num_classes)`` score array -- confirm).
            batch: Optional batch; when given, ``batch[1]`` holds the label
                indices to decode alongside the predictions.

        Returns:
            The result of ``self.decode`` for the predictions, or a
            ``(text, label)`` tuple when ``batch`` is provided.
        """
        preds = self._to_numpy(preds)
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if batch is None:
            return text
        label = batch[1]
        # Labels may already be host arrays; only tensors need conversion.
        if isinstance(label, torch.Tensor):
            label = label.detach().cpu().numpy()
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        """Prepend the end-of-sequence token ``'</s>'`` at class index 0.

        The inherited ``decode`` is expected to stop at the ``'</s>'``
        character (NRTRLabelDecode is not visible here -- confirm). A
        placeholder that never matches would let decoding run past the end
        token.
        """
        dict_character = ['</s>'] + dict_character
        return dict_character