# openrec/postprocess/mgp_postprocess.py
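"""Post-processing for MGP-STR style recognizers.

Decodes three prediction granularities -- character, BPE (GPT-2 vocabulary)
and WordPiece (BERT vocabulary) -- and, per sample, keeps the candidate with
the highest cumulative confidence.
"""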
from .ctc_postprocess import BaseRecLabelDecode


class MPGLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for MGP-STR outputs."""

    # '[s]' doubles as the end-of-sequence token (see self.EOS in __init__).
    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        super().__init__(character_dict_path, use_space_char)
        self.only_char = only_char
        self.EOS = '[s]'
        self.PAD = '[GO]'
        if not only_char:
            # transformers==4.2.1
            # Imported lazily so character-only use does not require
            # the transformers dependency.
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')

    def __call__(self, preds, batch=None, *args, **kwargs):
        # preds is either a single character-logit tensor of shape
        # (B, T, num_classes) or a list [char_preds, bpe_preds, wp_preds].
        if isinstance(preds, list):
            char_preds = preds[0].detach().cpu().numpy()
        else:
            char_preds = preds.detach().cpu().numpy()
        preds_idx = char_preds.argmax(axis=2)
        preds_prob = char_preds.max(axis=2)
        # Position 0 is the [GO] token, so decoding starts at index 1.
        char_text = self.char_decode(preds_idx[:, 1:], preds_prob[:, 1:])
        if batch is None:
            return char_text
        label = batch[1]
        label = self.char_decode(label[:, 1:].detach().cpu().numpy())
        if self.only_char:
            return char_text, label
        bpe_preds = preds[1].detach().cpu().numpy()
        bpe_preds_idx = bpe_preds.argmax(axis=2)
        bpe_preds_prob = bpe_preds.max(axis=2)
        bpe_text = self.bpe_decode(bpe_preds_idx[:, 1:],
                                   bpe_preds_prob[:, 1:])
        wp_preds = preds[2].detach()
        wp_preds_prob, wp_preds_idx = wp_preds.max(-1)
        wp_text = self.wp_decode(wp_preds_idx[:, 1:], wp_preds_prob[:, 1:])
        final_text = self.final_decode(char_text, bpe_text, wp_text)
        return char_text, bpe_text, wp_text, final_text, label

    def add_special_char(self, dict_character):
        # Prepend [GO] and [s] so their indices are fixed at 0 and 1.
        dict_character = self.list_token + dict_character
        return dict_character

    def final_decode(self, char_text, bpe_text, wp_text):
        """Keep, per sample, the candidate with the highest confidence."""
        result_list = []
        for candidates in zip(char_text, bpe_text, wp_text):
            # candidates holds the (text, confidence) pairs from the
            # character, BPE and WordPiece heads; ties favour the character
            # head, matching the original strict comparisons.
            result_list.append(max(candidates, key=lambda pair: pair[1]))
        return result_list

    def char_decode(self, text_index, text_prob=None):
        """Convert character indices into a text label."""
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            confidence = 1.0
            for idx in range(len(text_index[batch_idx])):
                try:
                    char = self.character[int(text_index[batch_idx][idx])]
                except IndexError:  # index outside the character set
                    continue
                if text_prob is not None:
                    confidence *= text_prob[batch_idx][idx]
                if char == self.EOS:  # stop at the end-of-sequence token
                    break
                if char == self.PAD:
                    continue
                char_list.append(char)
            result_list.append((''.join(char_list), confidence))
        return result_list

    def bpe_decode(self, text_index, text_prob):
        """Convert GPT-2 BPE token indices into a text label."""
        result_list = []
        for tokens, probs in zip(text_index, text_prob):
            text_decoded = []
            confidence = 1.0
            for bpe_idx, prob in zip(tokens, probs):
                token_str = self.bpe_tokenizer.decode([bpe_idx])
                if token_str == '#':  # '#' marks end-of-sequence here
                    break
                text_decoded.append(token_str)
                confidence *= prob
            result_list.append((''.join(text_decoded), confidence))
        return result_list

    def wp_decode(self, text_index, text_prob=None):
        """Convert BERT WordPiece token indices into a text label."""
        result_list = []
        for batch_idx, tokens in enumerate(text_index):
            wp_pred = self.wp_tokenizer.decode(tokens)
            sep_pos = wp_pred.find('[SEP]')
            if sep_pos >= 0:  # truncate at the separator if one was decoded
                wp_pred = wp_pred[:sep_pos]
            if text_prob is not None:
                try:
                    # 102 is the [SEP] token id in the bert-base-uncased vocab.
                    eos_index = tokens.cpu().tolist().index(102) + 1
                except ValueError:  # no [SEP] predicted
                    eos_index = -1
                wp_max_prob = text_prob[batch_idx][:eos_index]
                try:
                    wp_confidence_score = wp_max_prob.cumprod(
                        dim=0)[-1].cpu().numpy().sum()
                except IndexError:  # empty probability slice
                    wp_confidence_score = 0.0
            else:
                wp_confidence_score = 1.0
            result_list.append((wp_pred, wp_confidence_score))
        return result_list
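

if __name__ == '__main__':
    # Minimal smoke test -- a sketch, not part of the upstream module. It
    # assumes torch is installed and that BaseRecLabelDecode falls back to a
    # default lowercase-alphanumeric character set when character_dict_path
    # is None, as the CTC decoder it inherits from does.
    import torch

    decoder = MPGLabelDecode(only_char=True)  # skips the transformers deps
    num_classes = len(decoder.character)
    # Fake character probabilities: batch of 2, 27 steps ([GO] + 26 chars).
    char_probs = torch.randn(2, 27, num_classes).softmax(-1)
    print(decoder(char_probs))  # [(text, confidence), ...]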