from .ctc_postprocess import BaseRecLabelDecode


class MPGLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index."""
    # '[GO]' doubles as the padding token and '[s]' as the end-of-sequence
    # token; both are prepended to the character dictionary.
    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        super(MPGLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)
        self.only_char = only_char
        self.EOS = '[s]'
        self.PAD = '[GO]'
        if not only_char:
            # The BPE and WordPiece heads need HuggingFace tokenizers
            # (transformers==4.2.1).
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode raw model outputs into (text, confidence) tuples.

        `preds` is either the char-head scores alone or a list of
        [char, BPE, WordPiece] head scores, each of shape
        (batch, num_steps, num_classes). Step 0 of every head corresponds
        to the [GO] token and is skipped.
        """
        if isinstance(preds, list):
            char_preds = preds[0].detach().cpu().numpy()
        else:
            char_preds = preds.detach().cpu().numpy()
        # Greedy decoding: keep the most probable class at every step.
        preds_idx = char_preds.argmax(axis=2)
        preds_prob = char_preds.max(axis=2)
        char_text = self.char_decode(preds_idx[:, 1:], preds_prob[:, 1:])
        if batch is None:
            return char_text
        label = batch[1]
        label = self.char_decode(label[:, 1:].detach().cpu().numpy())
        if self.only_char:
            return char_text, label
        bpe_preds = preds[1].detach().cpu().numpy()
        wp_preds = preds[2]
        bpe_preds_idx = bpe_preds.argmax(axis=2)
        bpe_preds_prob = bpe_preds.max(axis=2)
        bpe_text = self.bpe_decode(bpe_preds_idx[:, 1:],
                                   bpe_preds_prob[:, 1:])
        wp_preds = wp_preds.detach()
        wp_preds_prob, wp_preds_idx = wp_preds.max(-1)
        wp_text = self.wp_decode(wp_preds_idx[:, 1:], wp_preds_prob[:, 1:])
        final_text = self.final_decode(char_text, bpe_text, wp_text)
        return char_text, bpe_text, wp_text, final_text, label

    def add_special_char(self, dict_character):
        # Reserve the first dictionary indices for [GO] and [s].
        dict_character = self.list_token + dict_character
        return dict_character

    def final_decode(self, char_text, bpe_text, wp_text):
        """Per sample, keep the prediction with the highest confidence
        among the character, BPE and WordPiece heads."""
        result_list = []
        for (char_pred, char_conf), (bpe_pred, bpe_conf), (wp_pred, wp_conf) \
                in zip(char_text, bpe_text, wp_text):
            final_text, final_prob = char_pred, char_conf
            if bpe_conf > final_prob:
                final_text, final_prob = bpe_pred, bpe_conf
            if wp_conf > final_prob:
                final_text, final_prob = wp_pred, wp_conf
            result_list.append((final_text, final_prob))
        return result_list

    def char_decode(self, text_index, text_prob=None):
        """Convert text-index into text-label."""
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            confidence = 1.0
            for idx in range(len(text_index[batch_idx])):
                try:
                    char = self.character[int(text_index[batch_idx][idx])]
                except IndexError:
                    continue
                if text_prob is not None:
                    # Confidence is the product of per-step probabilities.
                    confidence *= text_prob[batch_idx][idx]
                if char == self.EOS:  # stop at end-of-sequence
                    break
                if char == self.PAD:
                    continue
                char_list.append(char)
            text = ''.join(char_list)
            result_list.append((text, confidence))
        return result_list

    def bpe_decode(self, text_index, text_prob):
        """Convert text-index into text-label."""
        result_list = []
        for text, probs in zip(text_index, text_prob):
            text_decoded = []
            confidence = 1.0
            for bpe_idx, prob in zip(text, probs):
                token_str = self.bpe_tokenizer.decode([bpe_idx])
                if token_str == '#':  # '#' marks the end of the sequence
                    break
                text_decoded.append(token_str)
                confidence *= prob
            text = ''.join(text_decoded)
            result_list.append((text, confidence))
        return result_list

    def wp_decode(self, text_index, text_prob=None):
        """Convert text-index into text-label."""
        result_list = []
        for batch_idx, text in enumerate(text_index):
            wp_pred = self.wp_tokenizer.decode(text)
            # Truncate at [SEP]; keep the full string if [SEP] is absent
            # (str.find would otherwise return -1 and drop the last char).
            wp_pred_EOS = wp_pred.find('[SEP]')
            if wp_pred_EOS != -1:
                wp_pred = wp_pred[:wp_pred_EOS]
            if text_prob is not None:
                try:
                    # 102 is the [SEP] token id of bert-base-uncased.
                    wp_pred_EOS_index = text.cpu().tolist().index(102) + 1
                except ValueError:
                    wp_pred_EOS_index = -1
                wp_pred_max_prob = text_prob[batch_idx][:wp_pred_EOS_index]
                try:
                    # Confidence is the product of per-step probabilities.
                    wp_confidence_score = wp_pred_max_prob.cumprod(
                        dim=0)[-1].cpu().numpy().sum()
                except IndexError:
                    wp_confidence_score = 0.0
            else:
                wp_confidence_score = 1.0
            result_list.append((wp_pred, wp_confidence_score))
        return result_list
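

# ---------------------------------------------------------------------------
# Usage sketch: a minimal smoke test of the character-only path. This is an
# illustration, not part of the original module. It assumes that
# BaseRecLabelDecode falls back to a built-in lowercase-alphanumeric charset
# when character_dict_path is None and exposes it as `self.character`; run it
# as a module (python -m ...) so the relative import above resolves.
if __name__ == '__main__':
    import torch

    decoder = MPGLabelDecode(only_char=True)
    num_classes = len(decoder.character)  # charset plus [GO] and [s]
    # Fake per-step class scores for a batch of 2 images over 27 steps.
    char_scores = torch.rand(2, 27, num_classes)
    for text, conf in decoder(char_scores):
        print(text, conf)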