 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from .ctc_postprocess import BaseRecLabelDecode


class MPGLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for the MPG recognition head.

    The head emits up to three parallel prediction streams per image:
    character-level logits, BPE logits (GPT-2 vocabulary) and word-piece
    logits (BERT vocabulary).  Each stream is decoded independently and
    ``final_decode`` keeps the highest-confidence transcription.
    """
    SPACE = '[s]'
    GO = '[GO]'
    # Special tokens prepended to the character dictionary:
    # index 0 -> [GO] (padding), index 1 -> [s] (end-of-sequence).
    list_token = [GO, SPACE]

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        """
        Args:
            character_dict_path: path of the character dictionary file
                (forwarded to ``BaseRecLabelDecode``).
            use_space_char: whether the space character is part of the
                dictionary (forwarded to ``BaseRecLabelDecode``).
            only_char: if True, only the character stream is decoded and
                ``transformers`` is never imported.
            **kwargs: ignored; kept for config-driven construction.
        """
        super(MPGLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)
        self.only_char = only_char
        self.EOS = '[s]'
        self.PAD = '[GO]'
        if not only_char:
            # Lazy import so char-only mode has no transformers dependency.
            # transformers==4.2.1
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode network output into text.

        Args:
            preds: a tensor of character logits, or a list
                ``[char_logits, bpe_logits, wp_logits, ...]``.
            batch: optional ground-truth batch; ``batch[1]`` holds the
                character-index labels.

        Returns:
            ``char_text`` when ``batch`` is None;
            ``(char_text, label)`` when ``only_char``;
            otherwise ``(char_text, bpe_text, wp_text, final_text, label)``.
            Every *_text* value is a list of ``(text, confidence)`` pairs.
        """
        if isinstance(preds, list):
            char_preds = preds[0].detach().cpu().numpy()
        else:
            char_preds = preds.detach().cpu().numpy()

        preds_idx = char_preds.argmax(axis=2)
        preds_prob = char_preds.max(axis=2)
        # Position 0 is the [GO] token; decoding starts at position 1.
        char_text = self.char_decode(preds_idx[:, 1:], preds_prob[:, 1:])
        if batch is None:
            return char_text
        label = batch[1]
        label = self.char_decode(label[:, 1:].detach().cpu().numpy())
        if self.only_char:
            return char_text, label
        else:
            bpe_preds = preds[1].detach().cpu().numpy()
            wp_preds = preds[2]

            bpe_preds_idx = bpe_preds.argmax(axis=2)
            bpe_preds_prob = bpe_preds.max(axis=2)
            bpe_text = self.bpe_decode(bpe_preds_idx[:, 1:],
                                       bpe_preds_prob[:, 1:])

            # Word-piece stream stays a torch tensor: wp_decode uses
            # tensor ops (max/cumprod) directly.
            wp_preds = wp_preds.detach()
            wp_preds_prob, wp_preds_idx = wp_preds.max(-1)
            wp_text = self.wp_decode(wp_preds_idx[:, 1:], wp_preds_prob[:, 1:])

            final_text = self.final_decode(char_text, bpe_text, wp_text)
            return char_text, bpe_text, wp_text, final_text, label

    def add_special_char(self, dict_character):
        """Prepend the [GO]/[s] control tokens to the character dictionary."""
        dict_character = self.list_token + dict_character
        return dict_character

    def final_decode(self, char_text, bpe_text, wp_text):
        """Pick, per sample, the decoding with the highest confidence.

        Args:
            char_text, bpe_text, wp_text: lists of ``(text, confidence)``
                pairs of equal length (one entry per sample).

        Returns:
            List of ``(text, confidence)`` pairs, the per-sample winner.
        """
        result_list = []
        for (char_pred,
             char_pred_conf), (bpe_pred,
                               bpe_pred_conf), (wp_pred, wp_pred_conf) in zip(
                                   char_text, bpe_text, wp_text):
            final_text = char_pred
            final_prob = char_pred_conf
            if bpe_pred_conf > final_prob:
                final_text = bpe_pred
                final_prob = bpe_pred_conf
            if wp_pred_conf > final_prob:
                final_text = wp_pred
                final_prob = wp_pred_conf
            result_list.append((final_text, final_prob))
        return result_list

    def char_decode(self, text_index, text_prob=None):
        """Convert character indices into text.

        Args:
            text_index: (batch, seq) array of dictionary indices.
            text_prob: optional (batch, seq) array of per-step
                probabilities; when given, the per-sample confidence is
                their running product up to (and including) the stop step.

        Returns:
            List of ``(text, confidence)`` pairs; confidence is 1.0 when
            ``text_prob`` is None.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf = 1.0
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except IndexError:
                    # Index outside the dictionary (was a bare except,
                    # which also swallowed KeyboardInterrupt): skip step.
                    continue
                if text_prob is not None:
                    conf *= text_prob[batch_idx][idx]

                if char_idx == self.EOS:  # end of sequence
                    break
                if char_idx == self.PAD:
                    continue
                char_list.append(char_idx)

            text = ''.join(char_list)
            result_list.append((text, conf))
        return result_list

    def bpe_decode(self, text_index, text_prob):
        """Convert BPE (GPT-2) token indices into text.

        Decoding stops at the first token that decodes to ``'#'``
        (the stream's end marker).  Confidence is the product of the
        per-token probabilities of the kept tokens.
        """
        result_list = []
        for text, probs in zip(text_index, text_prob):
            text_decoded = []
            conf = 1.0
            for bpeindx, prob in zip(text, probs):
                tokenstr = self.bpe_tokenizer.decode([bpeindx])
                if tokenstr == '#':
                    break
                text_decoded.append(tokenstr)
                conf *= prob
            text = ''.join(text_decoded)
            result_list.append((text, conf))
        return result_list

    def wp_decode(self, text_index, text_prob=None):
        """Convert word-piece (BERT) token indices into text.

        Args:
            text_index: (batch, seq) tensor of token ids.
            text_prob: optional (batch, seq) tensor of per-token
                probabilities; confidence is their cumulative product up
                to the [SEP] token (id 102).

        Returns:
            List of ``(text, confidence)`` pairs; confidence is 1.0 when
            ``text_prob`` is None.
        """
        result_list = []
        for batch_idx, text in enumerate(text_index):
            wp_pred = self.wp_tokenizer.decode(text)
            sep_pos = wp_pred.find('[SEP]')
            # Bug fix: find() returns -1 when [SEP] is absent, and the
            # old unconditional slice silently dropped the last character.
            if sep_pos != -1:
                wp_pred = wp_pred[:sep_pos]
            if text_prob is not None:
                try:
                    wp_pred_EOS_index = text.cpu().tolist().index(102) + 1
                except ValueError:
                    # No [SEP] id in the sequence: fall back to dropping
                    # the final step, matching the original behavior.
                    wp_pred_EOS_index = -1
                wp_pred_max_prob = text_prob[batch_idx][:wp_pred_EOS_index]
                try:
                    wp_confidence_score = wp_pred_max_prob.cumprod(
                        dim=0)[-1].cpu().numpy().sum()
                except IndexError:
                    # Empty probability slice (e.g. [SEP] at position 0).
                    wp_confidence_score = 0.0
            else:
                wp_confidence_score = 1.0
            result_list.append((wp_pred, wp_confidence_score))
        return result_list