# openrec/preprocess/mgp_label_encode.py
'''
This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''
import numpy as np

from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode


class MGPLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index."""

    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        super(MGPLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        # character (str): set of the possible characters.
        # [GO] is the start token of the attention decoder; [s] is the
        # end-of-sentence token.
        self.batch_max_length = max_text_length + len(self.list_token)
        self.only_char = only_char
        if not only_char:
            # Subword tokenizers for the BPE and WordPiece label heads
            # (the reference implementation uses transformers==4.2.1).
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')

    def __call__(self, data):
        text = data['label']
        # Character-level label: charset indices wrapped with [GO]/[s]
        # and padded to batch_max_length.
        char_text, char_len = self.encode(text)
        if char_text is None:
            return None
        data['length'] = np.array(char_len)
        data['char_label'] = np.array(char_text)
        if self.only_char:
            return data
        # Subword-level labels for the BPE and WordPiece heads.
        bpe_text = self.bpe_encode(text)
        if bpe_text is None:
            return None
        wp_text = self.wp_encode(text)
        data['bpe_label'] = np.array(bpe_text)
        data['wp_label'] = wp_text
        return data
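    # For illustration: the fields produced by __call__ for a 5-character
    # word with only_char=False (shapes follow from batch_max_length above):
    #   data['length']     -> scalar array, here 5
    #   data['char_label'] -> int array of shape (batch_max_length,)
    #   data['bpe_label']  -> int array of shape (batch_max_length,)
    #   data['wp_label']   -> int array of shape (batch_max_length,)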

    def add_special_char(self, dict_character):
        # Prepend [GO] and [s] so they occupy the first two indices.
        dict_character = self.list_token + dict_character
        return dict_character
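    # Example (an assumption that holds only for the default lowercase
    # alphanumeric charset used when character_dict_path is None): the
    # final character list is ['[GO]', '[s]', '0', ..., '9', 'a', ..., 'z'],
    # so self.dict maps '[GO]' -> 0 and '[s]' -> 1.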

    def encode(self, text):
        """Convert a text label into a padded list of character indices."""
        if len(text) == 0:
            return None, None
        if self.lower:
            text = text.lower()
        length = len(text)
        text = [self.GO] + list(text) + [self.SPACE]
        text_list = []
        for char in text:
            if char not in self.dict:
                # Drop characters that are not in the charset.
                continue
            text_list.append(self.dict[char])
        if len(text_list) == 0 or len(text_list) > self.batch_max_length:
            return None, None
        # Pad with the [GO] index up to the fixed batch length.
        text_list = text_list + [self.dict[self.GO]
                                 ] * (self.batch_max_length - len(text_list))
        return text_list, length
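    # Worked example under the same default-charset assumption as above
    # ('a' -> 12, 'b' -> 13): encode('ab') returns length 2 and
    # [0, 12, 13, 1, 0, 0, ...] padded out to batch_max_length.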

    def bpe_encode(self, text):
        if len(text) == 0:
            return None
        token = self.bpe_tokenizer(text)['input_ids']
        # Wrap with the fixed start (1) and end (2) ids, following the
        # reference implementation, then pad with the [GO] index.
        text_list = [1] + token + [2]
        if len(text_list) == 0 or len(text_list) > self.batch_max_length:
            return None
        text_list = text_list + [self.dict[self.GO]
                                 ] * (self.batch_max_length - len(text_list))
        return text_list

    def wp_encode(self, text):
        # BERT WordPiece ids, padded/truncated to batch_max_length.
        wp_target = self.wp_tokenizer([text],
                                      padding='max_length',
                                      max_length=self.batch_max_length,
                                      truncation=True,
                                      return_tensors='np')
        return wp_target['input_ids'][0]
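

if __name__ == '__main__':
    # Minimal usage sketch, not part of the upstream file. It assumes the
    # default charset that BaseRecLabelEncode falls back to when
    # character_dict_path is None; only_char=True keeps the example free of
    # the transformers dependency and tokenizer downloads.
    encoder = MGPLabelEncode(max_text_length=25, only_char=True)
    sample = encoder({'label': 'hello'})
    if sample is not None:
        print(sample['length'])      # 5
        print(sample['char_label'])  # [GO] + char ids + [s], padded to 27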