File size: 1,260 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch

from .nrtr_postprocess import NRTRLabelDecode


class ABINetLabelDecode(NRTRLabelDecode):
    """Convert between text-label and text-index for ABINet outputs.

    Decoding itself is delegated to the parent ``NRTRLabelDecode.decode``;
    this class only selects which ABINet head output to decode, converts
    tensors to numpy, and prepends a ``'</s>'`` entry to the character
    dictionary (see ``add_special_char``).
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        # Extra keyword arguments are accepted (e.g. from a config dict) but
        # are intentionally not forwarded to the parent constructor.
        super().__init__(character_dict_path, use_space_char)

    @staticmethod
    def _to_numpy(value):
        """Return a detached CPU numpy copy of a torch tensor; pass other values through."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode model predictions, and optionally the ground-truth labels.

        Args:
            preds: Either a dict of ABINet outputs or a single tensor/array.
                For a dict, the last entry of ``preds['align']`` is used when
                that key is present and non-empty; otherwise
                ``preds['vision']`` is used. Assumes the selected output is
                laid out as (batch, seq_len, num_classes), since the class
                scores are reduced over axis 2. TODO: confirm the layout
                against the model head.
            batch: Optional batch whose element ``[1]`` holds the encoded
                labels, either as a torch tensor or as an array-like object.

        Returns:
            The result of ``self.decode`` on the predictions when ``batch`` is
            None; otherwise a ``(text, label)`` tuple of decoded predictions
            and decoded labels.
        """
        if isinstance(preds, dict):
            align_outputs = preds.get('align')
            # Prefer the final alignment-stage output; fall back to the
            # vision-only head when no alignment output is available.
            if align_outputs is not None and len(align_outputs) > 0:
                preds = align_outputs[-1]
            else:
                preds = preds['vision']
        preds = self._to_numpy(preds)

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        # ABINet does not emit CTC-style repeats, so duplicates are kept.
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if batch is None:
            return text
        label = self.decode(self._to_numpy(batch[1]))
        return text, label

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with ``'</s>'`` inserted at index 0."""
        # Index 0 is reserved for '</s>'; presumably the parent's decode treats
        # it as the end-of-sequence marker. Verify against NRTRLabelDecode.
        dict_character = ['</s>'] + dict_character
        return dict_character